DSP

FFT简单入门(板子)

2019-07-13 15:33发布

class="markdown_views prism-atom-one-light"> https://www.cnblogs.com/zzqsblog/p/5665654.html
下面其实是写给自己查阅的,真正要学可以看上面的。 OI里的fft并没有那么神奇,可以简单理解为加速卷积(多项式)的工具。 设n次多项式A,B,本来需要n^2的时间求其乘积(卷积),使用fft可以加速成n log n. #n次单位复数根
满足w^n=1的复数。
由复数乘法性质,幅角相加,长度相乘可知,w其实就是将单位圆均分为n份的那n个复数。
记为w1,…wn.
显然wj=cos(j2pi/n)+sin(j2pi/n)iwj=cos(j*2pi/n) + sin(j*2pi/n)*i #DFT
离散傅里叶变换(求点值表达,下文点值表达特指x取遍n次单位复数根的点值表达)
下文称a次多项式的次数界为a+1。
为了方便,次数界应补足为最近的2的幂。 (高位系数设0)
求一个次数界为n的多项式,当x取n单位复数根时(w0…wn-1)的n个值。
主要思想是分治,拆分为奇数幂次与偶数幂次,由于单位复数根的特殊性质(平方后减半)
最终式子:
A[i]           =A0[i]+A1[i]Wii<n/2A[i]~~~~~~~~~~~=A0[i] + A1[i] * W^i,i<n/2
A[i+n/2]=A0[i]A1[i]Wii<n/2A[i+n/2]=A0[i] -A1[i] * W^i,i<n/2 #逆DFT
已知次数界为n的多项式在n次单位复数根下的点值表达,求系数表达。
一发推导后可以发现,只要把n次主根取w1w^{-1},按照原先做dft,再将最后结果/次数界即可。 最终结果可能有误差,需要加上一个eps. #蝶形变换
小常数实现fft的方法。
这里写图片描述 先按二进制翻转(上限取次数界-1),然后从左到右做。 #NTT (数论变换)
模意义下的dft/idft
在模一个费马模数的前提下(P=k2a+1P=k2^a+1,比如998244353,g=3),我们可以将n次单位主根单位根替换为gP1ng^{frac {P-1} {n}},其中g是P的原根。
且满足n<=2^a
假设一个数g是P的原根,那么g^i mod P的结果两两不同,且有 1<g<P0<i<P1<g<P,0<i<P
逆dft中,负指数需要用逆元。注意long long问题 #调试方法
无???
如何验证对错:
对拍
观察fft后虚部是否为0。
对一个数列dft,idft看是否前后一致 #DFT #include #include #include #include typedef double db; typedef long long ll; #define com complex using namespace std; const int N=1e5+10; const db pi=acos(-1); int n,h[N*6],M; com q[N*6],r[N*6]; com a[N*6]; void dft(com *src,int sig) { for (int i=0; i>1; for (int i=0; i>n; for (int i=0; i>1]>>1) + ((i&1) * (M>>1)); //以次数界-1为长度,翻转二进制 dft(q,1); dft(r,1); for (int i=0; i #NTT #include #include using namespace std; const int N=4e5+10,mo=998244353,g=3; typedef long long ll; int n,m,M; int A[N],B[N],h[N]; ll w[N], iw[N]; ll ksm(ll x,ll y) { if (y==0) return 1; if (y==1) return x; ll t=ksm(x,y>>1); return t*t%mo*ksm(x,y&1)%mo; } void ntt(int *a,int sz,int sig) { for (int i = 1; i < sz; i++) h[i] = (h[i>>1]>>1) + (i & 1) * (sz >> 1); for (int i = 0; i <sz; i++) if (h[i]<i) swap(a[i],a[h[i]]); for (int m = 1; m < sz; m<<=1) { int td = M / (m<<1); for (int i = 0; i < sz; i += (m<<1)) { for (int j = 0; j < m; j++) { ll T = a[i+j+m] * (sig == 1 ? w[td * j] : iw[td * j]) % mo; a[i+j+m] = (a[i+j] - T) % mo; a[i+j] = (a[i+j] + T) % mo; } } } } int main() { freopen("test.in","r",stdin); cin>>n>>m;; for (int i=0; i<=n; i++) scanf("%d",&A[i]); for (int i=0; i<=m; i++) scanf("%d",&B[i]); for (M=1; M<=n+m; M<<=1); for (int i=1; i<M; i++) h[i]=(h[i>>1]>>1) + (i&1) * (M>>1); ll ww = ksm(3, (mo - 1) / M); iw[0] = w[0] = 1; for (int i = 1; i < M; i++) w[i] = w[i-1] * ww % mo; ww = ksm(ww, mo - 2); for (int i = 1; i < M; i++) iw[i] = iw[i-1] * ww % mo; ntt(A,M,1); ntt(B,M,1); for (int i=0; i<M; i++) A[i]=(ll)A[i]*B[i]%mo; ntt(A,M,-1); ll cs=ksm(M,mo-2); for (int i=0; i<=n+m; i++) printf("%lld ",(A[i]*cs%mo+mo)%mo); }