「CodeForces 438E」The Child and Binary Tree

给定一个元素个数为$n$的集合$c$和一个整数$m$,称一棵二叉树是好的当且仅当这棵二叉树的所有点的权值都属于$c$,规定一棵带点权二叉树的权值是该树中所有点权的总和。对于任意的整数$s$满足$1\leq s\leq m$,求出权值为$s$的好的二叉树的数量,答案对$998244353$取模。

$\texttt{Data Range:}1\leq n,m,c_i\leq 10^5$

链接

题解

一道很好的计数$\texttt{dp}$多项式花样板子题。

首先考虑计数$\texttt{dp}$,其实这个应该不难想。

考虑根节点的点权,枚举一下左子树点权,那么有

这里,$g_i=[i\in c]$。

但是这个算法复杂度不对啊qwq,考虑使用生成函数优化一下。

设$F(x)$为序列$dp$的生成函数(很明显是$\texttt{OGF}$),$G(x)$为序列$g$的生成函数,那么很明显的可以看到,上面的式子是卷积的形式,于是就再搞一搞,就会有

再整理一下

解一下这样一个方程,得到两个根

这个正负号看起来就觉得很不爽,所以考虑一下这个是加号还是减号。

题目保证了$1\leq c_i\leq 10^5$,所以可以知道$G_0=0$,代入一下是加号。

既然$F$是$dp$的生成函数,那么只需要输出$F_1,F_2,\cdots,F_m$就可以啦qwq。

时间复杂度$O(n\log n)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=3e5+51,MOD=998244353,G=3,INVG=332748118;
ll cnt,ccnt=1,fd;
ll f[MAXN],res[MAXN],g[MAXN],rev[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline ll qpow(ll base,ll exponent)
{
li res=1;
while(exponent)
{
if(exponent&1)
{
res=(li)res*base%MOD;
}
base=(li)base*base%MOD,exponent>>=1;
}
return res;
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
ll cur=0,res=0,omg=0;
for(register int i=0;i<cnt;++i)
{
if(i<rev[i])
{
swap(cp[i],cp[rev[i]]);
}
}
for(register int i=2;i<=cnt;i<<=1)
{
cur=i>>1,res=qpow(inv==1?G:INVG,(MOD-1)/i);
for(register ll *p=cp;p!=cp+cnt;p+=i)
{
omg=1;
for(register int j=0;j<cur;++j)
{
ll t=(li)omg*p[j+cur]%MOD,t2=p[j];
p[j+cur]=(t2-t+MOD)%MOD,p[j]=(t2+t)%MOD;
omg=(li)omg*res%MOD;
}
}
}
if(inv==-1)
{
ll invl=qpow(cnt,MOD-2);
for(register int i=0;i<=cnt;++i)
{
cp[i]=(li)cp[i]*invl%MOD;
}
}
}
inline void deriv(ll fd,ll *f,ll *res)
{
for(register int i=1;i<fd;++i)
{
res[i-1]=(li)f[i]*i%MOD;
}
res[fd-1]=0;
}
inline void integ(ll fd,ll *f,ll *res)
{
for(register int i=1;i<fd;++i)
{
res[i]=(li)f[i-1]*qpow(i,MOD-2)%MOD;
}
res[0]=0;
}
inline void inv(ll fd,ll *f,ll *res)
{
static ll tmp[MAXN];
if(fd==1)
{
res[0]=qpow(f[0],MOD-2);
return;
}
inv((fd+1)>>1,f,res);
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;++i)
{
tmp[i]=i<fd?f[i]:0;
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmp,cnt,1),NTT(res,cnt,1);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)(2-(li)tmp[i]*res[i]%MOD+MOD)%MOD*res[i]%MOD;
}
NTT(res,cnt,-1);
for(register int i=fd;i<cnt;++i)
{
res[i]=0;
}
}
inline void ln(ll fd,ll *f,ll *res)
{
static ll pinv[MAXN],der[MAXN];
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
inv(fd,f,pinv),deriv(fd,f,der);
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(pinv,cnt,1),NTT(der,cnt,1);
for(register int i=0;i<cnt;++i)
{
der[i]=(li)der[i]*pinv[i]%MOD;
}
NTT(der,cnt,-1),integ(fd,der,res);
for(register int i=0;i<cnt;++i)
{
der[i]=pinv[i]=0;
}
}
inline void exp(ll fd,ll *f,ll *res)
{
static ll texp[MAXN];
if(fd==1)
{
res[0]=1;
return;
}
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
exp((fd+1)>>1,f,res),ln(fd,res,texp);
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
texp[0]=(f[0]+1-texp[0]+MOD)%MOD;
for(register int i=1;i<fd;++i)
{
texp[i]=(f[i]-texp[i]+MOD)%MOD;
}
NTT(texp,cnt,1),NTT(res,cnt,1);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)res[i]*texp[i]%MOD;
}
NTT(res,cnt,-1);
for(register int i=0;i<cnt;++i)
{
texp[i]=0,res[i]=i<fd?res[i]:0;
}
}
inline void sqrt(ll fd,ll *f,ll *res)
{
static ll tsqrt[MAXN];
ln(fd,f,tsqrt);
ll cnt=1;
while(cnt<(fd<<1))
{
cnt<<=1;
}
for(register int i=0;i<=cnt;++i)
{
tsqrt[i]=tsqrt[i]&1?(tsqrt[i]+MOD)>>1:tsqrt[i]>>1;
}
exp(fd,tsqrt,res);
for(register int i=0;i<cnt;i++)
{
tsqrt[i]=0;
}
}
int main()
{
cnt=read(),fd=read()+1;
for(register int i=0;i<cnt;++i)
{
++f[read()];
}
while(ccnt<(fd<<1))
{
ccnt<<=1;
}
for(register int i=0;i<ccnt;++i)
{
f[i]=(MOD-((li)4*f[i]%MOD))%MOD;
}
++f[0],sqrt(fd,f,g);
g[0]=(g[0]+1)%MOD,inv(fd,g,res);
for(register int i=1;i<fd;i++)
{
printf("%d ",(res[i]<<1)%MOD);
}
}