class="markdown_views prism-atom-one-light">
https://www.cnblogs.com/zzqsblog/p/5665654.html
下面其实是写给自己查阅的,真正要学可以看上面的。
OI里的fft并没有那么神奇,可以简单理解为加速卷积(多项式)的
工具。
设n次多项式A,B,本来需要n^2的时间求其乘积(卷积),使用fft可以加速成n log n.
#n次单位复数根
满足w^n=1的复数。
由复数乘法性质,幅角相加,长度相乘可知,w其实就是将单位圆均分为n份的那n个复数。
记为w1,…wn.
显然
wj=cos(j∗2pi/n)+sin(j∗2pi/n)∗i
#DFT
离散傅里叶变换(求点值表达,下文点值表达特指x取遍n次单位复数根的点值表达)
下文称a次多项式的次数界为a+1。
为了方便,次数界应补足为最近的2的幂。 (高位系数设0)
求一个
次数界为n的多项式,当x取n单位复数根时(w0…wn-1)的n个值。
主要思想是分治,拆分为奇数幂次与偶数幂次,由于单位复数根的特殊性质(平方后减半)
最终式子:
A[i] =A0[i]+A1[i]∗Wi,i<n/2
A[i+n/2]=A0[i]−A1[i]∗Wi,i<n/2
#逆DFT
已知次数界为n的多项式在n次单位复数根下的点值表达,求系数表达。
一发推导后可以发现,只要把n次主根取
w−1,按照原先做dft,
再将最后结果/次数界即可。
最终结果可能有误差,需要加上一个eps.
#蝶形变换
小常数实现fft的方法。
先按二进制翻转(上限取次数界-1),然后从左到右做。
#NTT (数论变换)
模意义下的dft/idft
在模一个费马模数的前提下(
P=k2a+1,比如998244353,g=3),我们可以将n次单位主根单位根替换为
gnP−1,其中g是P的原根。
且满足n<=2^a
假设一个数g是P的原根,那么g^i mod P的结果两两不同,且有 1<g<P,0<i<P,
逆dft中,负指数需要用逆元。注意long long问题
#调试方法
无???
如何验证对错:
对拍
观察fft后虚部是否为0。
对一个数列dft,idft看是否前后一致
#DFT
#include
#include
#include
#include
typedef double db;
typedef long long ll;
#define com complex
using namespace std;
const int N=1e5+10;
const db pi=acos(-1);
int n,h[N*6],M;
com q[N*6],r[N*6];
com a[N*6];
void dft(com *src,int sig) {
for (int i=0; i>1;
for (int i=0; i>n;
for (int i=0; i>1]>>1) + ((i&1) * (M>>1));
//以次数界-1为长度,翻转二进制
dft(q,1); dft(r,1);
for (int i=0; i
#NTT
#include
#include
using namespace std;
const int N=4e5+10,mo=998244353,g=3;
typedef long long ll;
int n,m,M;
int A[N],B[N],h[N];
ll w[N], iw[N];
ll ksm(ll x,ll y) {
if (y==0) return 1;
if (y==1) return x;
ll t=ksm(x,y>>1);
return t*t%mo*ksm(x,y&1)%mo;
}
void ntt(int *a,int sz,int sig) {
for (int i = 1; i < sz; i++)
h[i] = (h[i>>1]>>1) + (i & 1) * (sz >> 1);
for (int i = 0; i <sz; i++)
if (h[i]<i) swap(a[i],a[h[i]]);
for (int m = 1; m < sz; m<<=1) {
int td = M / (m<<1);
for (int i = 0; i < sz; i += (m<<1)) {
for (int j = 0; j < m; j++) {
ll T = a[i+j+m] * (sig == 1 ? w[td * j] : iw[td * j]) % mo;
a[i+j+m] = (a[i+j] - T) % mo;
a[i+j] = (a[i+j] + T) % mo;
}
}
}
}
int main() {
freopen("test.in","r",stdin);
cin>>n>>m;;
for (int i=0; i<=n; i++) scanf("%d",&A[i]);
for (int i=0; i<=m; i++) scanf("%d",&B[i]);
for (M=1; M<=n+m; M<<=1);
for (int i=1; i<M; i++)
h[i]=(h[i>>1]>>1) + (i&1) * (M>>1);
ll ww = ksm(3, (mo - 1) / M);
iw[0] = w[0] = 1;
for (int i = 1; i < M; i++) w[i] = w[i-1] * ww % mo;
ww = ksm(ww, mo - 2);
for (int i = 1; i < M; i++) iw[i] = iw[i-1] * ww % mo;
ntt(A,M,1);
ntt(B,M,1);
for (int i=0; i<M; i++) A[i]=(ll)A[i]*B[i]%mo;
ntt(A,M,-1);
ll cs=ksm(M,mo-2);
for (int i=0; i<=n+m; i++) printf("%lld ",(A[i]*cs%mo+mo)%mo);
}