「Luogu P5245」【模板】多项式快速幂

给定$n-1$次多项式$F(x)$与整数$k$,求$\bmod x^n$意义下的$(F(x))^k$。

$\texttt{Data Range:}n\leq 10^5,2\leq k\leq 10^{10^5}$

前置知识

多项式基本操作,不会请右转模板区qwq。

题解

首先大力推一波式子

取下对数

再取一下指数

很明显,时间复杂度是$O(n\log n)$,但是如果像我一样写代码会$\texttt{TLE 8-20}$,于是来践行OI界的优良传统卡常数。

前方大图警告





中间有$\texttt{selftest}$的是我的自测,数据比这一题强。

首先看一份比较$\texttt{naive}$的代码(只有$\texttt{35 pts}$,吸氧后在$\texttt{selftest}$上跑还只有$\texttt{45070ms}$):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=3e5+51,MOD=998244353,G=3,INVG=332748118;
ll fd,exponent;
ll f[MAXN],res[MAXN],tmp[MAXN],pinv[MAXN],der[MAXN],texp[MAXN],texp2[MAXN];
ll g[MAXN],rev[MAXN],root[MAXN],invl[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline ll readm(ll mod)
{
register li num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(((num<<3)+(num<<1))%MOD+(ch-'0'))%MOD;
ch=getchar();
}
return num*neg;
}
inline ll qadd(ll x,ll y,ll mod)
{
return x+y>mod?x+y-mod:x+y;
}
inline ll qmin(ll x,ll y,ll mod)
{
return x-y<0?x-y+mod:x-y;
}
inline ll qpow(ll base,ll exponent,ll mod)
{
li res=1;
while(exponent)
{
if(exponent&1)
{
res=(li)res*base%mod;
}
base=(li)base*base%mod,exponent>>=1;
}
return res;
}
inline void NTT(ll *cp,ll cnt,ll inv,ll mod)
{
ll cur=0,res=0,omg=0;
for(register int i=0;i<cnt;++i)
{
if(i<rev[i])
{
swap(cp[i],cp[rev[i]]);
}
}
for(register int i=2;i<=cnt;i<<=1)
{
cur=i>>1,res=qpow(inv==1?G:INVG,(mod-1)/i,mod);
for(register ll *p=cp;p!=cp+cnt;p+=i)
{
omg=1;
for(register int j=0;j<cur;++j)
{
ll t=(li)omg*p[j+cur]%mod;
p[j+cur]=qmin(p[j],t,mod),p[j]=qadd(p[j],t,mod);
omg=(li)omg*res%mod;
}
}
}
if(inv==-1)
{
ll invl=qpow(cnt,mod-2,mod);
for(register int i=0;i<=cnt;++i)
{
cp[i]=(li)cp[i]*invl%mod;
}
}
}
inline void deriv(ll fd,ll *f,ll *res,ll mod)
{
for(register int i=1;i<fd;++i)
{
res[i-1]=(li)f[i]*i%mod;
}
res[fd-1]=0;
}
inline void integ(ll fd,ll *f,ll *res,ll mod)
{
for(register int i=1;i<fd;++i)
{
res[i]=(li)f[i-1]*qpow(i,mod-2,mod)%mod;
}
res[0]=0;
}
inline void inv(ll fd,ll *f,ll *res,ll mod)
{
if(fd==1)
{
res[0]=qpow(f[0],mod-2,mod);
return;
}
inv((fd+1)>>1,f,res,mod);
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;++i)
{
tmp[i]=i<fd?f[i]:0;
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmp,cnt,1,mod),NTT(res,cnt,1,mod);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)qmin(2,(li)tmp[i]*res[i]%mod,mod)*res[i]%mod;
}
NTT(res,cnt,-1,mod);
for(register int i=fd;i<cnt;++i)
{
res[i]=0;
}
}
inline void ln(ll fd,ll *f,ll *res,ll mod)
{
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
inv(fd,f,pinv,mod),deriv(fd,f,der,mod);
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(pinv,cnt,1,mod),NTT(der,cnt,1,mod);
for(register int i=0;i<cnt;++i)
{
der[i]=(li)der[i]*pinv[i]%mod;
}
NTT(der,cnt,-1,mod),integ(fd,der,res,mod);
for(register int i=0;i<cnt;++i)
{
der[i]=pinv[i]=0;
}
}
inline void exp(ll fd,ll *f,ll *res,ll mod)
{
if(fd==1)
{
res[0]=1;
return;
}
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
exp((fd+1)>>1,f,res,mod),ln(fd,res,texp,mod);
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
texp[0]=qmin(f[0]+1,texp[0],mod);
for(register int i=1;i<fd;++i)
{
texp[i]=qmin(f[i],texp[i],mod);
}
NTT(texp,cnt,1,mod),NTT(res,cnt,1,mod);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)res[i]*texp[i]%mod;
}
NTT(res,cnt,-1,mod);
for(register int i=0;i<cnt;++i)
{
texp[i]=0,res[i]=i<fd?res[i]:0;
}
}
int main()
{
fd=read(),exponent=readm(MOD);
for(register int i=0;i<fd;++i)
{
f[i]=read();
}
ln(fd,f,g,MOD);
for(register int i=0;i<fd;++i)
{
g[i]=(li)g[i]*exponent%MOD;
}
exp(fd,g,res,MOD);
for(register int i=0;i<fd;++i)
{
printf("%d ",res[i]);
}
}

首先考虑把一些调用次数多而短的函数给搞掉。

一眼看过去,调用次数最多的是$\texttt{qadd}$和$\texttt{qmin}$,于是考虑吧这两个函数搞掉,把一些副本放到函数里面开$\texttt{static}$,于是就优化到了$\texttt{31078ms}$。

接下来,就可以考虑压缩$\texttt{deriv}$和$\texttt{integ}$了,尽管代码里只有调用一次,但是$\texttt{exp}$中会调用$\log n$次,所以果断搞掉,现在时间是$\texttt{28972ms}$。

接着考虑删掉一些不必要的参数,因为传$\texttt{int}$是$O(32)$的。于是可以考虑删掉$\texttt{mod}$,因为这题只有一个模数,用不着加这样一个参。

经过毒瘤卡常后,一个$\texttt{45000+ms}$的代码被优化成了$\texttt{12000-ms}$,可以见得卡常是个好东西。

附赠$\texttt{selftest}$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=3e5+51,MOD=998244353,G=3,INVG=332748118;
ll fd,exponent;
ll f[MAXN],res[MAXN],g[MAXN],rev[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline ll readm()
{
register li num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=((num<<3)+(num<<1)+(ch-'0'))%MOD;
ch=getchar();
}
return num*neg;
}
inline ll qpow(ll base,ll exponent)
{
li res=1;
while(exponent)
{
if(exponent&1)
{
res=(li)res*base%MOD;
}
base=(li)base*base%MOD,exponent>>=1;
}
return res;
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
ll cur=0,res=0,omg=0;
for(register int i=0;i<cnt;++i)
{
if(i<rev[i])
{
swap(cp[i],cp[rev[i]]);
}
}
for(register int i=2;i<=cnt;i<<=1)
{
cur=i>>1,res=qpow(inv==1?G:INVG,(MOD-1)/i);
for(register ll *p=cp;p!=cp+cnt;p+=i)
{
omg=1;
for(register int j=0;j<cur;++j)
{
ll t=(li)omg*p[j+cur]%MOD,t2=p[j];
p[j+cur]=(t2-t+MOD)%MOD,p[j]=(t2+t)%MOD;
omg=(li)omg*res%MOD;
}
}
}
if(inv==-1)
{
ll invl=qpow(cnt,MOD-2);
for(register int i=0;i<=cnt;++i)
{
cp[i]=(li)cp[i]*invl%MOD;
}
}
}
inline void inv(ll fd,ll *f,ll *res)
{
static ll tmp[MAXN];
if(fd==1)
{
res[0]=qpow(f[0],MOD-2);
return;
}
inv((fd+1)>>1,f,res);
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;++i)
{
tmp[i]=i<fd?f[i]:0;
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmp,cnt,1),NTT(res,cnt,1);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)(2-(li)tmp[i]*res[i]%MOD+MOD)%MOD*res[i]%MOD;
}
NTT(res,cnt,-1);
for(register int i=fd;i<cnt;++i)
{
res[i]=0;
}
}
inline void ln(ll fd,ll *f,ll *res)
{
static ll pinv[MAXN],der[MAXN];
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
inv(fd,f,pinv);
for(register int i=1;i<fd;++i)
{
der[i-1]=(li)f[i]*i%MOD;
}
der[fd-1]=0;
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(pinv,cnt,1),NTT(der,cnt,1);
for(register int i=0;i<cnt;++i)
{
der[i]=(li)der[i]*pinv[i]%MOD;
}
NTT(der,cnt,-1);
for(register int i=1;i<fd;++i)
{
res[i]=(li)der[i-1]*qpow(i,MOD-2)%MOD;
}
res[0]=0;
for(register int i=0;i<cnt;++i)
{
der[i]=pinv[i]=0;
}
}
inline void exp(ll fd,ll *f,ll *res)
{
static ll texp[MAXN];
if(fd==1)
{
res[0]=1;
return;
}
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
exp((fd+1)>>1,f,res),ln(fd,res,texp);
for(register int i=0;i<cnt;++i)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
texp[0]=(f[0]+1-texp[0]+MOD)%MOD;
for(register int i=1;i<fd;++i)
{
texp[i]=(f[i]-texp[i]+MOD)%MOD;
}
NTT(texp,cnt,1),NTT(res,cnt,1);
for(register int i=0;i<cnt;++i)
{
res[i]=(li)res[i]*texp[i]%MOD;
}
NTT(res,cnt,-1);
for(register int i=0;i<cnt;++i)
{
texp[i]=0,res[i]=i<fd?res[i]:0;
}
}
int main()
{
fd=read(),exponent=readm();
for(register int i=0;i<fd;++i)
{
f[i]=read();
}
ln(fd,f,g);
for(register int i=0;i<fd;++i)
{
g[i]=(li)g[i]*exponent%MOD;
}
exp(fd,g,res);
for(register int i=0;i<fd;++i)
{
printf("%d ",res[i]);
}
}