模幂运算的几种解决方法

2019-04-14 16:28发布

  【问题】
  计算a**b%c的值。
  其中,"**"代表幂(Python中就是这样表示的);"%"代表取模运算。
  【分析】
  首先由模运算的性质,可以得到下面的公式:
  (a*b) % c = (a%c) * b % c 【公式一】
  将【公式一】继续展开,可得下面的公式:
  (a*b) % c = (a%c) * b % c = (b%c) * (a%c) % c = (a%c) * (b%c) % c 【公式二】
  下面是几种解决方法: 【方法1】利用公式一,使用递归方法计算。
  #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); if (0==b) { return 1; } else { return fun(a, b-1, c)*a%c; } } int main() { int a, b, c; while (3==scanf("%d%d%d",&a,&b,&c)) { printf("%d ",fun(a,b,c));// calculate a**b%c } return 0; } 不难看出此方法的缺点是当b越大的时候需要递归的次数就越多,因此就可能会发生stack overflow的错误。所以,此方法非常简单但只适用于b比较小的情况。VS2008下测试,当递归到4624次时发生stack overflow。
  【方法2】利用公式二,使用递归方法计算。
  想办法在方法1的基础上减少递归次数,发现利用公式二可以做到。
  (a*b) % c = (a%c) * (b%c) % c 【公式二】
  当b是奇数的时候,f(b) = f(b-1) * (a % c) % c = a * f(b-1) % c
  当b是偶数的时候,f(b) = f(b/2) * f(b/2) % c #include #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); if (0==b) { return 1; } else { return (1==(b&1)) ? a*fun(a, b-1, c)%c : fun(a, b/2, c)*fun(a, b/2, c)%c; } } int main() { int a, b, c; clock_t beg, end; while (3==scanf("%d%d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lf ",double(end-beg)/CLOCKS_PER_SEC); } return 0; } VS2008测试:a=3; b=2147483647;(取int型的最大值) c=1000时,需要约225s可计算出结果为787。
  此方法减少了递归次数,在所能表示的类型里可以防止栈溢出的问题(下面可推证),但是发现运算效率还是太慢,对代码进行检查可以发现问题出在下面这行代码中:
  fun(a, b/2, c)*fun(a, b/2, c)%c;
  这种写法会导致递归调用函数的次数可能达到2*lb(b),虽然递归深度要远小于方法1,但递归次数要远远大于方法1(可以设置一个局部静态变量显示),所以效率很低。解决方法可以先用一个临时变量来保存函数调用的返回值,优化方法如下:
  #include #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); if (0==b) { return 1; } else { //return (1==(b&1)) ? a*fun(a, b-1, c)%c : fun(a, b/2, c)*fun(a, b/2, c)%c; if (b&1)// odd { return a*fun(a, b-1, c)%c; } else// even { int tmp=fun(a, b/2, c);// note here! return tmp*tmp%c; } } } int main() { int a, b, c; clock_t beg, end; while (3==scanf("%d%d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lf ",double(end-beg)/CLOCKS_PER_SEC); } return 0; } 用上述同样的数据测试,发现计算所用时间几乎为0.00s,效率得到很大提高。
  思考:虽然这种方法减少了递归次数,但是当b特别大的时候是否还会出现stack overflow的情况。
  测试:将b改为__int64或long long类型,看是否会发生stack overflow并计算时间。
  可以粗略地分析出:在32位平台上用这种方法可以计算出b最大为2^4624时的结果而不会发生栈溢出。 #include #include #include // calculate a**b%c int fun(int a, __int64 b, int c) { assert(a && (b>=0) && c); // 打印出递归次数 /*static int itime=0; printf("%d ",itime++);*/ if (0==b) { return 1; } else { if (b&1)// odd { return a*fun(a, b-1, c)%c; } else// even { int tmp=fun(a, b/2, c);// note here! return tmp*tmp%c; } } } int main() { int a, c; __int64 b; clock_t beg,end; while (3==scanf("%d%I64d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lf ",double(end-beg)/CLOCKS_PER_SEC); } return 0; } 测试:a=3; b=9223372036854775807; c=1000; 输出187,用时几乎为0.00s。
  【方法3】利用公式一,使用递推方法(非递归)计算。 #include #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); int res=1; while (b>0) { res*=a; res%=c; --b; } return res; } int main() { int a, b, c; clock_t beg,end; while (3==scanf("%d%d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lf ",double(end-beg)/CLOCKS_PER_SEC); } return 0; } 测试:a=3; b=2147483647;(取int型的最大值) c=1000时,用时约37.02s可计算出结果为787。
  因为此方法需要循环b次,时间复杂度为O(b),效率比较低,但是优点是不会发生stack overflow。
  【方法4】利用公式二,使用递推方法(非递归)计算。 #include #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); int res=1; while (b) { if (b&1)// odd { res=res*a%c; --b; } else// even { a=a*a%c; b>>=1; } } return res; } int main() { int a, b, c; clock_t beg, end; while (3==scanf("%d%d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lf ",double(end-beg)/CLOCKS_PER_SEC); } return 0; } 测试:a=3; b=2147483647;(取int型的最大值) c=1000时,用时约0.00s可计算出结果为787。此方法最好的时间复杂度为O(lb(b)),且不会发生stack overflow。
  【方法5】与方法4等价的另一种形式。
  不对b进行减1或除以2,而将b表示为二进制数,通过判断二进制位上是0还是1来计算。这种方法的时间复杂度最小,即O(nbit),n是b的二进制表示的位数。比如:
  b==8(dec.)==1000(bin.),用方法4需要循环2+1+1+0=4次;用方法3也需要循环4次。
  b==7(dec.)==0111(bin.),用方法4需要循环2+2+1=5次;用方法3需要循环4次。
  使用数组实现 : #include #include #include // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); int res=1, n=0;// n+1 is the length of b in binary int bit_b[32]={0}; while (b) { bit_b[n++]=(b%2); b>>=1; } for (int i=n-1; i>=0; --i) { res=res*res%c; if (bit_b[i]) { res=res*a%c; } } return res; } int main() { int a, b, c; clock_t beg, end; while (3==scanf("%d%d%d",&a,&b,&c)) { beg=clock(); printf("%d ",fun(a,b,c));// calculate a**b%c end=clock(); printf("time used: %.2lfs ",(double)(end-beg)/CLOCKS_PER_SEC); } return 0; } 使用bitset容器实现 : #include #include #include using std::bitset; // calculate a**b%c int fun(int a, int b, int c) { assert(a && (b>=0) && c); int res=1, n;// n+1 is the length of b in binary bitset bitvec_b(b); n=bitvec_b.size()-1; while (0==bitvec_b[n]) --n; for (int i=n; i>=0; --i) { res=res*res%c; if (bitvec_b[i]) { res=res*a%c; } } return res; } int main() { int a, b, c; while (3==scanf("%d%d%d",&a,&b,&c)) { printf("%d ",fun(a,b,c));// calculate a**b%c } return 0; } 参考 :