(模板)多项式乘法对任意数取模

2019-04-13 13:02发布

// 多项式乘法 系数对MOD=1000000007取模, 常数巨大,慎用 // 只要选的K个素数乘积大于MOD*MOD*N,理论上MOD可以任取。 #define MOD 1000000007 #define K 3 const int m[K] = {1004535809, 998244353, 104857601}; #define G 3 int qpow(int x, int k, int P) { int ret = 1; while(k) { if(k & 1) ret = 1LL * ret * x % P; k >>= 1; x = 1LL * x * x % P; } return ret; } struct _NTT { int wn[25], P; void init(int _P) { P = _P; for(int i = 1; i <= 21; ++i) { int t = 1 << i; wn[i] = qpow(G, (P - 1) / t, P); } } void change(int *y, int len) { for(int i = 1, j = len / 2; i < len - 1; ++i) { if(i < j) swap(y[i], y[j]); int k = len / 2; while(j >= k) { j -= k; k /= 2; } j += k; } } void NTT(int *y, int len, int on) { change(y, len); int id = 0; for(int h = 2; h <= len; h <<= 1) { ++id; for(int j = 0; j < len; j += h) { int w = 1; for(int k = j; k < j + h / 2; ++k) { int u = y[k]; int t = 1LL * y[k+h/2] * w % P; y[k] = u + t; if(y[k] >= P) y[k] -= P; y[k+h/2] = u - t + P; if(y[k+h/2] >= P) y[k+h/2] -= P; w = 1LL * w * wn[id] % P; } } } if(on == -1) { for(int i = 1; i < len / 2; ++i) swap(y[i], y[len-i]); int inv = qpow(len, P - 2, P); for(int i = 0; i < len; ++i) y[i] = 1LL * y[i] * inv % P; } } void mul(int A[], int B[], int len) { NTT(A, len, 1); NTT(B, len, 1); for(int i = 0; i < len; ++i) A[i] = 1LL * A[i] * B[i] % P; NTT(A, len, -1); } }ntt[K]; int tmp[N][K], t1[N], t2[N]; int r[K][K]; int CRT(int a[]) { int x[K]; for(int i = 0; i < K; ++i) { x[i] = a[i]; for(int j = 0; j < i; ++j) { int t = (x[i] - x[j]) % m[i]; if(t < 0) t += m[i]; x[i] = 1LL * t * r[j][i] % m[i]; } } int mul = 1, ret = x[0] % MOD; for(int i = 1; i < K; ++i) { mul = 1LL * mul * m[i-1] % MOD; ret += 1LL * x[i] * mul % MOD; if(ret >= MOD) ret -= MOD; } return ret; } void mul(int A[], int B[], int len) { for(int id = 0; id < K; ++id) { for(int i = 0; i < len; ++i) { t1[i] = A[i]; t2[i] = B[i]; } ntt[id].mul(t1, t2, len); for(int i = 0; i < len; ++i) tmp[i][id] = t1[i]; } for(int i = 0; i < len; ++i) { A[i] = CRT(tmp[i]); } } void init() { for(int i = 0; i < K; ++i) { for(int j = 0; j < i; ++j) { r[j][i] = qpow(m[j], m[i] - 2, m[i]); } } for(int i = 0; i < K; ++i) { ntt[i].init(m[i]); } }