bzoj 2956: 模积和 (反演)

2019-04-14 16:10发布

2956: 模积和

Time Limit: 10 Sec  Memory Limit: 128 MB
Submit: 1276  Solved: 574
[Submit][Status][Discuss]

Description

 求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。
  

Input

第一行两个数n,m。

Output

  一个整数表示答案mod 19940417的值

Sample Input


3 4

Sample Output

1

样例说明
  答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1

数据规模和约定
  对于100%的数据n,m<=10^9。

HINT

Source

中国国家队清华集训 2012-2013 第一天
[Submit][Status][Discuss]

题解:数论+乘法逆元 刚开始没看到i!=j 这个条件,所以直接将式子化成了sigma(i=1..n)(n mod i)*sigma(i=1..m)(m mod i) 就可以把两部分分开计算,那么这道题就变成了CQOI余数之和 sigma (i=1..n) n mod i =sigma(i=1..n) n-(floor(n/i)*i) 因为floor(n/i)的取值是一段一段的,所以可以在O(sqrt(n))的时间内出解。 然后我们考虑从答案中除去不符合条件的,即sigma(i=1..min(n,m) (n mod i)*(m mod i) =sigma(i=1..min(n,m))(n-floor(n/i)*i)*(m-floor(m/i)*i) =sigma(i=1..min(n,m))n*m-m*floor(n/i)*i-n*floor(m/i)*i-floor(n/i)*floor(m/i)*i*i 可以在O(sqrt(n)+sqrt(m))的时间内出解 需要用到平方和公式sum=n*(n+1)*(2n+1)/6 #include #include #include #include #include #define p 19940417 #define LL long long using namespace std; LL n,m,inv1; LL quickpow(LL num,int x) { LL base=num%p; LL ans=1; while (x) { if (x&1) ans=ans*base%p; x>>=1; base=base*base%p; } return ans; } LL calc(LL n,LL k) { LL i,j; LL ans=0; for (i=1,j=0;i<=k;i=j+1) { if (n/i!=0) j=min(n/(n/i),k); else j=k; ans+=((j-i+1)*(i+j)/2)%p*(n/i)%p; ans%=p; } return ans; } void exgcd(LL a,LL b,LL &x,LL &y) { if (!b) { x=1; y=0; return; } exgcd(b,a%b,x,y); LL t=y; y=x-(a/b)*y; x=t; } LL inv(LL a,LL b) { LL x,y; exgcd(a,b,x,y); return x; } LL sum(LL n1) { return (LL)n1*(n1+1)%p*(2*n1+1)%p*inv1%p; //return n1*(n1+1)*(2*n1+1)/6; } LL calc1(LL k) { LL i,j; LL ans=n*m%p*k%p; for (i=1,j=0;i<=k;i=j+1) { j=min(n/(n/i),m/(m/i)); j=min(j,k); LL t=m*(n/i)+n*(m/i); t=(t%p+p)%p; t=((j-i+1)*(i+j)/2)%p*t; LL t1=sum(j)-sum(i-1); t1=(t1%p+p)%p; ans=ans+((n/i)*(m/i)%p*t1%p-t+p)%p; ans=(ans%p+p)%p; } return ans; } int main() { freopen("a.in","r",stdin); freopen("my.out","w",stdout); scanf("%d%d",&n,&m); inv1=inv(6,p); LL t1=calc(n,n); t1=((LL)n*n-t1)%p; LL t2=calc(m,m); t2=((LL)m*m-t2)%p; LL t3=calc1(min(n,m)); //cout<