0%

「清华集训2017」生成树计数

题目链接

前置知识:求数列幂和

给定数列 \(\{a_n\}\),对 \(0\leqslant t \leqslant k\),求 \(\sum\limits_{i=1}^na_i^t\)\(n,t\leqslant 10^5\)

我们写出幂次和对应的生成函数 \(F(x)\) \[ \begin{aligned} F(x)=&\sum\limits_{t=0}\left(\sum\limits_{i=1}^na_i^t\right)x^t\\ =&\sum\limits_{i=1}^n\sum\limits_{t=0}a_i^tx^t\\ =&\sum\limits_{i=1}^n\frac{1}{1-a_ix} \end{aligned} \] 然后我们考虑 \(G(x)=\sum\limits_{i=1}^n\left(\ln(1-a_ix)\right)'\),化简则有 \[ \begin{aligned} G(x)=&\sum\limits_{i=1}^n\left(\ln(1-a_ix)\right)'\\ =&-\sum\limits_{i=1}^n\frac{a_i}{1-a_ix}\\ =&-\sum\limits_{t=0}\sum\limits_{i=1}^na_i^{t+1}x^t \end{aligned} \] 如果我们能求得 \(G(x)\),则很好就能求出 \(F(x)\)。同时发现 \[ \begin{aligned} G(x)=&\sum\limits_{i=1}^n\left(\ln(1-a_ix)\right)'\\ =&\left(\sum\limits_{i=1}^n\ln(1-a_ix)\right)'\\ =&\left(\ln(\prod\limits_{i=1}^n(1-a_ix))\right)' \end{aligned} \] \(\prod\limits_{i=1}^n(1-a_ix)\) 可以通过分治+NTT 计算,然后求个 \(\ln\) 再求导就好了。复杂度 \(O(n\log^2n)\)

然后我们考虑这道题。

显然,我们可以把每一个连通块看成一个点,然后只需要考虑连通块的连接方式,每连接一条边额外乘上连通块大小即可。

我们考虑枚举 prufer 序列,如果一个数在 prufer 序列中出现了 \(d\) 次,则其的度数为 \(d+1\)

那么我们考虑枚举每个点在 prufer 序列中出现的次数 \(d_i\),则我们有 \[ \begin{aligned} ans=&\sum\limits_{\sum d=n-2}(n-2)!\left(\prod_{i=1}^n(d_i+1)^m\right)\left(\sum_{i=1}^n(d_i+1)^m\right)\left(\prod_{i=1}^n\frac{a_i^{d_i+1}}{d_i!}\right)\\ =&(n-2)!\prod_{i=1}^na_i\sum\limits_{\sum d=n-2}\left(\prod_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^m\right)\left(\sum_{i=1}^n(d_i+1)^m\right) \end{aligned} \] 前面两项是定值,我们先不管。我们考虑后面那个东西,它等价于 \[ \sum\limits_{\sum d=n-2}\sum\limits_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod\limits_{j\not=i}\frac{a_j^{d_j}}{d_j!}(d_j+1)^m \] 我们考虑构造两个多项式 \(A(x)=\sum\limits_{i}\dfrac{(i+1)^{2m}x^i}{i!},B(x)=\sum\limits_{i}\dfrac{(i+1)^mx^i}{i!}\),设答案关于 \(\sum d\) 的生成函数为 \(F(x)\),则有 \[ \begin{aligned} F(x)&=\sum\limits_{i=1}^nA(a_ix)\prod\limits_{j\not=i}B(a_jx)\\ &=\sum\limits_{i=1}^n\frac{A(a_ix)}{B(a_ix)}\prod_{j=1}^n B(a_jx)\\ &=\sum\limits_{i=1}^n\frac{A(a_ix)}{B(a_ix)}\exp\left(\sum\limits_{j=1}^n\ln(B(a_jx))\right) \end{aligned} \] 如果我们求出了 \(\dfrac{A(x)}{B(x)}\)\(\ln(B(x))\),我们只需要每一项的系数乘上 \(\sum\limits_{i=1}^na_i^t\) 即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
namespace polynomials
{
int const p=998244353,g=3;
int const N=(1<<18)+1;
int w[N],iv[N],r[N],last;
typedef vector<int> vec;
int mod(int x){return x>=p?x-p:x;}
int pw(int x,int y)
{
int res=1;
while(y)
{
if(y&1)res=1ll*res*x%p;
x=1ll*x*x%p;
y>>=1;
}
return res;
}
void init(int n)
{
int lim=1;
while(lim<n)lim<<=1;
iv[1]=1;
for(int i=2;i<=lim;i++)iv[i]=mod(p-1ll*(p/i)*iv[p%i]%p);
for(int i=1;i<lim;i<<=1)
{
int wn=pw(g,(p-1)/(i<<1));
for(int j=0,ww=1;j<i;j++,ww=1ll*ww*wn%p)w[i+j]=ww;
}
}
void ntt(vec &f,int n,int op)
{
f.resize(n);
for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?(n>>1):0);
for(int i=1;i<n;i++)if(i<r[i])swap(f[i],f[r[i]]);
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=i<<1)
for(int k=0;k<i;k++)
{
int x=f[j+k],y=1ll*f[i+j+k]*w[i+k]%p;
f[j+k]=mod(x+y);f[i+j+k]=mod(x-y+p);
}
if(op==-1)
{
reverse(&f[1],&f[n]);
for(int i=0;i<n;i++)f[i]=1ll*f[i]*iv[n]%p;
}
}
void getinv(vec f,vec &g,int n)
{
static vec x;
g.resize(n);
if(n==1){g[0]=pw(f[0],p-2);return;}
getinv(f,g,(n+1)>>1);
int lim=1;
while(lim<(n<<1))lim<<=1;
x.resize(lim);
for(int i=0;i<n;i++)x[i]=f[i];
for(int i=n;i<lim;i++)x[i]=0;
g.resize(lim);
ntt(x,lim,1),ntt(g,lim,1);
for(int i=0;i<lim;i++)g[i]=1ll*g[i]*(2-1ll*g[i]*x[i]%p+p)%p;
ntt(g,lim,-1);
g.resize(n);
}
void getln(vec f,vec &g,int n)
{
static vec x;
getinv(f,x,n);
for(int i=0;i<n-1;i++)f[i]=1ll*f[i+1]*(i+1)%p;
f[n-1]=0;
int lim=1;
while(lim<((n<<1)-1))lim<<=1;
ntt(f,lim,1),ntt(x,lim,1);
for(int i=0;i<lim;i++)x[i]=1ll*x[i]*f[i]%p;
ntt(x,lim,-1);
g.resize(n);
g[0]=0;
for(int i=1;i<n;i++)g[i]=1ll*x[i-1]*iv[i]%p;
}
void getln(vec f,vec &g,vec &h,int n)
{
static vec x;
getinv(f,x,n);h=x;
for(int i=0;i<n-1;i++)f[i]=1ll*f[i+1]*(i+1)%p;
f[n-1]=0;
int lim=1;
while(lim<((n<<1)-1))lim<<=1;
ntt(f,lim,1),ntt(x,lim,1);
for(int i=0;i<lim;i++)x[i]=1ll*x[i]*f[i]%p;
ntt(x,lim,-1);
g.resize(n);
g[0]=0;
for(int i=1;i<n;i++)g[i]=1ll*x[i-1]*iv[i]%p;
}
void getexp(vec f,vec &g,int n)
{
static vec x;
g.resize(n);
if(n==1){g[0]=1;return;}
int m,lim=1;
getexp(f,g,m=((n+1)>>1));
while(lim<(n<<1))lim<<=1;
g.resize(lim);
getln(g,x,n);
x.resize(lim);
for(int i=0;i<n;i++)x[i]=mod(f[i]-x[i]+p);
for(int i=n;i<lim;i++)x[i]=0;
x[0]=mod(x[0]+1);
ntt(g,lim,1),ntt(x,lim,1);
for(int i=0;i<lim;i++)g[i]=1ll*g[i]*x[i]%p;
ntt(g,lim,-1);
g.resize(n);
}
}
using namespace polynomials;
int n,m,a[30005];
vec solve(int l,int r)
{
if(l==r){vec res;res.pb(1);res.pb(p-a[l]);return res;}
int mid=(l+r)>>1,lim=1;
vec x=solve(l,mid),y=solve(mid+1,r);
while(lim<(r-l+2))lim<<=1;
ntt(x,lim,1),ntt(y,lim,1);
for(int i=0;i<lim;i++)x[i]=1ll*x[i]*y[i]%p;
ntt(x,lim,-1);x.resize(r-l+2);
return x;
}
int fac[30005],inv[30005];
int main()
{
int prod=1;
init(131071);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),prod=1ll*prod*a[i]%p;
vec t=solve(1,n+1),res;
getln(t,res,n+1);
for(int i=0;i<n;i++)res[i]=1ll*res[i+1]*(i+1)%p;
for(int i=0;i<n;i++)
if(res[i])res[i]=p-res[i];
for(int i=n-1;i>=1;i--)res[i]=res[i-1];
res[0]=n;
vec a,b,c,d;
a.resize(n),b.resize(n);
fac[0]=inv[0]=1;
for(int i=1;i<n;i++)fac[i]=1ll*fac[i-1]*i%p;
inv[n-1]=pw(fac[n-1],p-2);
for(int i=n-2;i>=1;i--)inv[i]=1ll*inv[i+1]*(i+1)%p;
for(int i=0;i<n;i++)
{
int t=pw(i+1,m);
b[i]=1ll*t*inv[i]%p;
a[i]=1ll*b[i]*t%p;
}
getln(b,d,c,n);
int lim=1;
while(lim<(n<<1))lim<<=1;
ntt(a,lim,1),ntt(c,lim,1);
for(int i=0;i<lim;i++)a[i]=1ll*a[i]*c[i]%p;
ntt(a,lim,-1),a.resize(n);
for(int i=0;i<n;i++)a[i]=1ll*a[i]*res[i]%p,d[i]=1ll*d[i]*res[i]%p;
c.clear();
getexp(d,c,n);
ntt(a,lim,1),ntt(c,lim,1);
for(int i=0;i<lim;i++)a[i]=1ll*a[i]*c[i]%p;
ntt(a,lim,-1);
printf("%lld",1ll*a[n-2]*fac[n-2]%p*prod%p);
return 0;
}