【模板】任意模数NTT(中国剩余定理版,O(1)long long快速乘)

2019-04-14 21:09发布

任意模数的 NTT:题目给定的模数 MOD 不是 $a \cdot 2^k + 1$ 的形式,或者 $2^k$ 小于所需的变换长度,因此无法直接在模 MOD 下做 NTT。假定 NTT 的作用长度为 $len$,系数的值不大于 $x$,则相乘后的系数不大于 $len \cdot x^2$。如果我们选取合适的多个模数 $p_1, p_2, p_3$(它们有相同的原根),使 $p_1 p_2 p_3 > len \cdot x^2$,并分别以 $p_1, p_2, p_3$ 为模做 NTT,得到系数 C[1][]、C[2][]、C[3][],那么实际系数 $x[i]$ 满足:$x[i] \equiv C[1][i] \pmod{p_1}$,$x[i] \equiv C[2][i] \pmod{p_2}$,$x[i] \equiv C[3][i] \pmod{p_3}$。由中国剩余定理,通解为 $x[i] = X_0 + k\,p_1 p_2 p_3$,而 $x[i] \leq len \cdot x^2 < p_1 p_2 p_3 \Rightarrow x[i] = X_0$。但这里由于 $p_i$ 的选取,$p_1 p_2 p_3$ 会爆 long long,故可以先求解 $x[i] \equiv x_0 \pmod{p_1 p_2}$,再由 $x[i] = x_0 + k_1 p_1 p_2 = C[3][i] + k_3 p_3$ 得到 $k_1 p_1 p_2 \equiv C[3][i] - x_0 \pmod{p_3}$,求出 $k_1$ 后即可得到 $X_0 = x_0 + k_1 p_1 p_2$。这里有个神奇的 O(1) long long 乘法,根据的是 $A \bmod B = A - B \cdot q$,其中 $q$ 为 $\frac{A}{B}$ 的
// floor, together with the fact that subtraction after overflow remains equivalent in the
// modular sense (the original author flags this part as "not entirely sure").
//
// Arbitrary-modulus polynomial multiplication.
//   1. Multiply the two polynomials with an NTT under each of three NTT-friendly primes p[1..3]
//      (all of which have primitive root g = 3).
//   2. Recombine the three residues of every coefficient with the Chinese Remainder Theorem.
//   3. Reduce the recombined coefficient modulo the requested modulus MOD.
//
// Input : N M MOD, then the N+1 coefficients of A, then the M+1 coefficients of B.
// Output: the N+M+1 coefficients of A*B modulo MOD, each followed by a single space.
//
// The stand-alone `multi` snippet that preceded the full program was identical to the
// definition below, so it is given only once here (a second copy would be a redefinition).
// The header names of the five original #include lines were lost; the five below cover every
// library name the program uses.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
using LL = long long;

const int MAXN = 4e5 + 5;
// p[1..3]: the three NTT primes (index 0 unused); g: their common primitive root;
// MX = p[1] * p[2], the modulus of the first CRT merge (about 4.7e17, fits in long long).
const LL p[] = {0, 469762049ll, 998244353ll, 1004535809ll}, g = 3, MX = 469762049ll * 998244353ll;
LL N, M, MOD;
LL A[MAXN], B[MAXN], C[4][MAXN], F[MAXN], G[MAXN], ans[MAXN];

LL qpow(LL, int, LL);
void bit_reverse(int, LL *);
void DNT(int len, LL *a, int type, LL mod);
void getC(int, int);
LL multi(LL, LL, LL);

/// Reads the two polynomials, multiplies them under the three primes, and prints the
/// CRT-recombined coefficients modulo MOD.
/// Returns 1 (with a message on stderr) if the input is malformed or does not fit the
/// statically sized buffers; returns 0 otherwise.
int main() {
    ios::sync_with_stdio(false);
    cin >> N >> M >> MOD;
    if (!cin || N < 0 || M < 0 || N >= MAXN || M >= MAXN || MOD <= 0) {
        cerr << "invalid input" << '\n';
        return 1;
    }
    for (int i = 0; i <= N; i++) cin >> A[i];
    for (int i = 0; i <= M; i++) cin >> B[i];

    // Smallest power of two strictly greater than the degree N + M of the product.
    int len = 1;
    while (len <= M + N) len <<= 1;
    if (len > MAXN) {
        // The transforms index [0, len); refuse sizes that would overrun the static arrays.
        cerr << "input too large" << '\n';
        return 1;
    }

    getC(1, len), getC(2, len), getC(3, len);

    // The modular inverses used by the CRT do not depend on the coefficient index,
    // so compute them once (Fermat: a^(p-2) is the inverse of a modulo a prime p).
    const LL inv_p2_mod_p1 = qpow(p[2] % p[1], p[1] - 2, p[1]);
    const LL inv_p1_mod_p2 = qpow(p[1] % p[2], p[2] - 2, p[2]);
    const LL inv_MX_mod_p3 = qpow(MX % p[3], p[3] - 2, p[3]);

    for (int i = 0; i <= N + M; i++) {
        // x: the coefficient modulo MX = p1 * p2, merged from its residues mod p1 and mod p2.
        const LL x = (multi(C[1][i] * p[2] % MX, inv_p2_mod_p1, MX) +
                      multi(C[2][i] * p[1] % MX, inv_p1_mod_p2, MX)) % MX;
        // k1 solves x + k1 * MX == C[3][i] (mod p3); the true coefficient is x + k1 * MX.
        const LL k1 = multi((C[3][i] % p[3] - x % p[3] + p[3]) % p[3], inv_MX_mod_p3, p[3]);
        // Reduce x + k1 * MX modulo MOD without ever forming the (too large) value itself.
        // multi() keeps the product safe even when MOD is too large for a plain 64-bit multiply.
        ans[i] = (multi(k1, MX % MOD, MOD) + x % MOD) % MOD;
        cout << ans[i] << " ";
    }
    return 0;
}

/// Computes C[I][0..len) = (A * B) mod p[I] via forward NTT, pointwise product, inverse NTT.
/// @param I    index of the prime to use (1..3)
/// @param len  transform length; a power of two greater than N + M and at most MAXN
void getC(int I, int len) {
    const LL mod = p[I];
    fill_n(F, len, 0LL);
    fill_n(G, len, 0LL);
    // Bring every input coefficient into [0, mod) so the butterflies never see negative
    // or oversized operands.
    for (int i = 0; i <= N; i++) F[i] = (A[i] % mod + mod) % mod;
    for (int i = 0; i <= M; i++) G[i] = (B[i] % mod + mod) % mod;
    DNT(len, F, 1, mod), DNT(len, G, 1, mod);
    // Only indices [0, len) belong to the transform.
    for (int i = 0; i < len; i++) C[I][i] = F[i] * G[i] % mod;
    DNT(len, C[I], -1, mod);
}

/// In-place iterative number-theoretic transform of a[0..len).
/// @param len   transform length (power of two)
/// @param a     values in [0, mod); overwritten with the transform
/// @param type  1 for the forward transform, -1 for the inverse (which also divides by len)
/// @param mod   NTT prime with primitive root g
void DNT(int len, LL *a, int type, LL mod) {
    bit_reverse(len, a);
    for (int l = 2; l <= len; l <<= 1) {
        const int mid = l >> 1;
        // Principal l-th root of unity; the inverse transform uses its modular inverse.
        LL wn = qpow(g, (mod - 1) / l, mod);
        if (type == -1) wn = qpow(wn, mod - 2, mod);
        for (int i = 0; i < len; i += l) {
            LL w = 1;
            for (int j = 0; j < mid; j++, w = w * wn % mod) {
                const LL x = a[i + j], y = w * a[i + j + mid] % mod;
                a[i + j] = (x + y) % mod;
                a[i + j + mid] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1) {
        // Scale by len^{-1}; kept as LL rather than narrowed to int.
        const LL inv = qpow(len, mod - 2, mod);
        for (int i = 0; i < len; i++) a[i] = a[i] * inv % mod;
    }
}

/// Permutes a[0..len) into bit-reversed index order (len must be a power of two).
void bit_reverse(int len, LL *a) {
    for (int i = 0, j = 0; i < len; i++) {
        if (i > j) swap(a[i], a[j]);
        // Add one to j in bit-reversed order: clear leading set bits, then set the next bit.
        int k = len >> 1;
        for (; (j & k); j ^= k, k >>= 1);
        j ^= k;
    }
}

/// Returns x^n mod `mod` by binary exponentiation.
/// Requires n >= 0 and mod small enough that (mod - 1)^2 fits in long long.
LL qpow(LL x, int n, LL mod) {
    LL res = 1;
    x %= mod;
    while (n) {
        if (n & 1) res = res * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return res;
}

/// O(1) modular multiplication: returns (a * b) mod `mod` for 0 < mod < 2^62.
///
/// Uses a * b mod m = a * b - m * floor(a * b / m). The quotient is estimated in long double
/// and may be off by one; the product and the subtraction are carried out in unsigned 64-bit
/// arithmetic, where wraparound is well defined, so the small true difference survives and is
/// then normalised into [0, mod).
/// NOTE: relies on long double having a 64-bit mantissa (x86 extended precision); where
/// long double is plain double the quotient estimate is not accurate enough for large moduli.
LL multi(LL a, LL b, LL mod) {
    using ULL = unsigned long long;
    a %= mod;
    if (a < 0) a += mod;
    b %= mod;
    if (b < 0) b += mod;
    const ULL q = static_cast<ULL>(static_cast<long double>(a) / mod * b + 1e-3);
    // Unsigned -> signed conversion of the wrapped difference is modular on mainstream
    // compilers (and guaranteed modular since C++20).
    LL r = static_cast<LL>(static_cast<ULL>(a) * static_cast<ULL>(b) - q * static_cast<ULL>(mod));
    r %= mod;
    if (r < 0) r += mod;
    return r;
}