上一次用到原根这个东西,应该是一年之前了吧……
所谓原根,就是指对于某个数字P,满足它的原根g,g的0~P-2次幂在模P的意义下互不相同,或者说g的1~P-1次幂在模P的意义下无不相同。一般来说,原根只能够进行枚举求解,但是由于原根一般较小(2或者3),所以可以接受。
对于本题,要求对于每一个数字ak计算在模P意义下,有多少个ai*aj等于ak。直接做显然是O(N^2)的,不能够满足条件。看到这个题目形式和数据范围,很容易往FFT等方向上想,但是这里的ai的范围是1e9级别,而且这个是乘法,直接FFT也是不行的。由于P满足是质数,所以我们考虑用原根。
我们令g为质数P的原根,那么对于一个数字ai,唯一存在一个数字bi,使得
。那么我们把所有的ai用这种形式表示,于是对于原本的式子ai*aj≡ak(mod P),有
,那么可以有bi+bj=bk。这样问题就从乘法变成了加法,FFT就可以排上用场了。而且你会发现,这里的bi的大小都是在P以内的。这样直接FFT即可,最后特殊处理以下0的情况即可。具体见代码:
#include
#define PI 3.1415926535
#define eps 1e-8
#define mod 998244353
#define LL long long
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define INF 0x3f3f3f3f
#define sf(x) scanf("%d",&x)
#define sc(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define clr(x,n) memset(x,0,sizeof(x[0])*(n+5))
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
using namespace std;
const int N = 500010;
struct Complex
{
double r,i;
Complex(double real=0.0,double image=0.0)
{
r=real; i=image;
}
Complex operator +(const Complex o){return Complex(r+o.r,i+o.i);}
Complex operator -(const Complex o){return Complex(r-o.r,i-o.i);}
Complex operator *(const Complex o){return Complex(r*o.r-i*o.i,r*o.i+i*o.r);}
} b[N];
namespace FFT
{
int len;
void brc(Complex *y, int l)
{
register int i,j,k;
for( i = 1, j = l / 2; i < l - 1; i++)
{
if (i < j) swap(y[i], y[j]);
k = l / 2; while ( j >= k) j -= k,k /= 2;
if (j < k) j += k;
}
}
void FFT(Complex *y, int len, double on)
{
register int h, j, k;
Complex u, t; brc(y, len);
for(h = 2; h <= len; h <<= 1)
{
Complex wn(cos(on * 2 * PI / h), sin(on * 2 * PI / h));
for(j = 0; j < len; j += h)
{
Complex w(1, 0);
for(k = j; k < j + h / 2; k++)
{
u = y[k]; t = w * y[k + h / 2];
y[k] = u + t; y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on<0) for (int i = 0; i < len; i++) y[i].r/=len;
}
void multiply(Complex *A,int lenA)
{
for(len = 1; len < 2*lenA - 1; len <<= 1);
for (int i = lenA; i < len; i++) A[i] = 0;
FFT(A,len ,1 );
for (int i = 0;i < len; i++) A[i] = A[i] * A[i];
FFT(A, len, -1);
}
}
int a[N],id[N],n,P,g;
LL ans[N];
inline bool check(int g)
{
LL tmp=1;
for(int i=1;i=P) puts("0");
else if (a[i]==0) printf("%lld
",t*(n-t)*2+t*t);
else printf("%lld
",ans[id[a[i]]]);
}
return 0;
}