Fast Fourier Transform
The Fast Fourier Transform (FFT) can be applied to the multiplication of polynomials, giving O( N lg N ) time complexity. One of its most common uses is fast multiplication of big numbers. It exploits the properties of the complex roots of unity to make a divide-and-conquer strategy possible. At the end of this post I will put some papers with a mathematical proof of FFT and a brief summary I have made.
Here's a problem you can try solving with FFT:
C: 高速フーリエ変換 - AtCoder Typical Contest 001 | AtCoder
See this slide deck for a rigorous mathematical proof and further understanding.
This post is very easy to understand
Tutorial on FFT — The tough made simple. ( Part 1 ) - Codeforces
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

typedef std::complex<double> Z;

const double PI = std::acos(-1.0);
const Z I(0, 1);  // imaginary unit, i = 0 + 1 * i

// Recursive radix-2 FFT.
//   n   : transform length; must be a power of two and at most a.size().
//   a   : input values; only a[0 .. n-1] are read.
//   inv : when true the conjugate roots of unity are used (inverse transform).
//         The result is NOT divided by n; the caller does that.
// Returns a vector holding the n transformed values.
std::vector<Z> fft(int n, const std::vector<Z> &a, bool inv) {
    if (n == 1) return std::vector<Z>(1, a[0]);

    // Split into even- and odd-indexed halves and transform each recursively.
    const int half = n / 2;
    std::vector<Z> aeven(half), aodd(half);
    for (int i = 0; i < half; ++i) {
        aeven[i] = a[2 * i];
        aodd[i] = a[2 * i + 1];
    }
    aeven = fft(half, aeven, inv);
    aodd = fft(half, aodd, inv);

    // Butterfly step: omega walks through the n-th roots of unity.
    const double theta = (inv ? -2.0 : 2.0) * PI / n;
    const Z deltaOmega = std::exp(I * theta);
    Z omega(1, 0);
    std::vector<Z> y(n);
    for (int k = 0; k < half; ++k) {
        const Z twisted = omega * aodd[k];
        y[k] = aeven[k] + twisted;
        y[k + half] = aeven[k] - twisted;
        omega *= deltaOmega;
    }
    return y;
}

// Reads n, then n pairs (g[i], h[i]) for i = 1..n, and prints for every
// k = 1..2n the convolution sum of g[i] * h[k - i], one value per line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n;
    if (!(std::cin >> n) || n < 0) return 0;  // nothing to do on malformed input

    // Both sequences are stored 1-indexed (index 0 stays zero), so each has
    // n + 1 coefficients; the product needs at least 2 * (n + 1) slots.
    int nn = 1;
    while (nn < 2 * (n + 1)) nn <<= 1;

    std::vector<Z> yg(nn), yh(nn);  // value-initialised to 0 + 0i
    for (int i = 1; i <= n; ++i) {
        int gi = 0, hi = 0;
        std::cin >> gi >> hi;
        yg[i] = Z(gi, 0);
        yh[i] = Z(hi, 0);
    }

    yg = fft(nn, yg, false);
    yh = fft(nn, yh, false);

    // Pointwise product in the frequency domain == convolution of coefficients.
    std::vector<Z> yf(nn);
    for (int i = 0; i < nn; ++i) yf[i] = yg[i] * yh[i];
    yf = fft(nn, yf, true);

    // The inverse transform is unnormalised, so divide by nn before rounding.
    // '\n' instead of endl avoids flushing the stream once per output line.
    for (int i = 1; i <= 2 * n; ++i)
        std::cout << std::llround(yf[i].real() / nn) << '\n';
    return 0;
}
And here is a fast-multiplication problem on SPOJ: SPOJ.com - Problem VFMUL
I don't know why, but this gets a runtime error (RE) when submitted as C++14, yet is accepted (AC) with C++ 5.1. Anyway, I will take a look at it later.
Note that long double is required (as far as I know) for numbers with this many digits.
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <vector>

// Original author's note on this typedef: "key difference 31 and the
// advanced - 235" (presumably SPOJ problem numbers - TODO confirm).
typedef long double daburu;
typedef std::complex<daburu> Z;
typedef long long ll;

// The argument is cast so the long double overload of acos is selected;
// acos(-1.0) would only yield pi to double precision.
const daburu PI = std::acos(static_cast<daburu>(-1));
const Z I(0, 1);  // imaginary unit

const int MAXN = 300002 * 8;  // capacity of the input digit buffers
const int B = 4;              // decimal digits per limb; the original note says 5 also works
const ll BASE = 10000;        // 10^B

// Recursive radix-2 FFT.
//   n   : transform length; must be a power of two and at most a.size().
//   a   : input values; only a[0 .. n-1] are read.
//   inv : when true the conjugate roots of unity are used (inverse transform).
//         The result is NOT divided by n; the caller does that.
// Returns a vector holding the n transformed values.
std::vector<Z> fft(int n, const std::vector<Z> &a, bool inv) {
    if (n == 1) return std::vector<Z>(1, a[0]);

    const int half = n / 2;
    std::vector<Z> aeven(half), aodd(half);
    for (int i = 0; i < half; ++i) {
        aeven[i] = a[2 * i];
        aodd[i] = a[2 * i + 1];
    }
    aeven = fft(half, aeven, inv);
    aodd = fft(half, aodd, inv);

    const daburu theta = (inv ? -2 : 2) * PI / n;
    const Z deltaOmega = std::exp(I * theta);
    Z omega(1, 0);
    std::vector<Z> y(n);
    for (int i = 0; i < half; ++i) {
        const Z twisted = omega * aodd[i];
        y[i] = aeven[i] + twisted;
        y[i + half] = aeven[i] - twisted;
        omega *= deltaOmega;
    }
    return y;
}

// Splits a decimal digit string into base-10^B limbs, least significant first.
static std::vector<int> toLimbs(const char *digits) {
    const int len = static_cast<int>(std::strlen(digits));
    std::vector<int> limbs;
    limbs.reserve(len / B + 1);
    for (int hi = len - 1; hi >= 0; hi -= B) {
        int value = 0, scale = 1;
        for (int j = hi; j > hi - B && j >= 0; --j) {
            value += (digits[j] - '0') * scale;
            scale *= 10;
        }
        limbs.push_back(value);
    }
    return limbs;
}

// Input buffers; scanf("%s") is unbounded, so inputs are assumed to have
// fewer than MAXN digits.
char l1[MAXN], l2[MAXN];

// Reads T, then T pairs of non-negative decimal integers, and prints the
// product of each pair on its own line.
int main() {
    int T;
    if (std::scanf("%d", &T) != 1) return 0;

    while (T-- > 0) {
        if (std::scanf("%s %s", l1, l2) != 2) break;

        const std::vector<int> g = toLimbs(l1);
        const std::vector<int> h = toLimbs(l2);
        const int n = static_cast<int>(g.size());
        const int m = static_cast<int>(h.size());

        int nn = 1;
        while (nn < std::max(n, m) * 2) nn <<= 1;

        // The buffers are created per test case with exactly nn elements.
        // Keeping MAXN-sized globals and reassigning them from fft()'s
        // nn-element return value would leave them too short for a later,
        // larger test case (the assignment replaces the storage), which is an
        // out-of-bounds write.
        std::vector<Z> yg(nn), yh(nn);  // value-initialised to 0 + 0i
        for (int i = 0; i < n; ++i) yg[i] = Z(g[i], 0);
        for (int i = 0; i < m; ++i) yh[i] = Z(h[i], 0);

        yg = fft(nn, yg, false);
        yh = fft(nn, yh, false);

        std::vector<Z> yf(nn);
        for (int i = 0; i < nn; ++i) yf[i] = yg[i] * yh[i];
        yf = fft(nn, yf, true);

        // Round each coefficient (the inverse transform is unnormalised, hence
        // the division by nn) and propagate carries in base BASE.
        std::vector<ll> ans(nn);
        ll carry = 0;
        for (int i = 0; i < nn; ++i) {
            const ll cur = std::llround(yf[i].real() / nn) + carry;
            ans[i] = cur % BASE;
            carry = cur / BASE;
        }
        while (carry > 0) {
            ans.push_back(carry % BASE);
            carry /= BASE;
        }

        // Skip leading zero limbs but always keep the lowest one.
        int top = static_cast<int>(ans.size()) - 1;
        while (top > 0 && ans[top] == 0) --top;

        // The most significant limb is printed bare, the rest zero-padded to B digits.
        std::printf("%d", static_cast<int>(ans[top]));
        for (int i = top - 1; i >= 0; --i)
            std::printf("%0*d", B, static_cast<int>(ans[i]));
        std::puts("");
    }
    return 0;
}
Here is more of a general template, verified on http://zerojudge.tw/ShowProblem (ZOJ a577).
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <vector>

typedef long double daburu;
typedef std::complex<daburu> Z;
typedef long long ll;

// The argument is cast so the long double overload of acos is selected;
// acos(-1.0) would only yield pi to double precision.
const daburu PI = std::acos(static_cast<daburu>(-1));
const Z I(0, 1);  // imaginary unit

const int MAXN = ((1 << 17) + 1) * 8;  // capacity of the input digit buffers
const int B = 4;                       // decimal digits per limb
const ll BASE = 10000;                 // 10^B

// In-place recursive radix-2 FFT.
//   n   : transform length; must be a power of two and at most a.size().
//   a   : on entry the input values in a[0 .. n-1]; on return the transform.
//         Elements at index n and above are left untouched.
//   inv : when true the conjugate roots of unity are used (inverse transform).
//         The result is NOT divided by n; the caller does that.
void fft(int n, std::vector<Z> &a, bool inv) {
    if (n == 1) return;

    const int half = n / 2;
    std::vector<Z> aeven(half), aodd(half);
    for (int i = 0; i < half; ++i) {
        aeven[i] = a[2 * i];
        aodd[i] = a[2 * i + 1];
    }
    fft(half, aeven, inv);
    fft(half, aodd, inv);

    const daburu theta = (inv ? -2 : 2) * PI / n;
    const Z deltaOmega = std::exp(I * theta);
    Z omega(1, 0);
    for (int i = 0; i < half; ++i) {
        const Z twisted = omega * aodd[i];
        a[i] = aeven[i] + twisted;
        a[i + half] = aeven[i] - twisted;
        omega *= deltaOmega;
    }
}

// Prints the number whose base-BASE limbs (least significant first, not yet
// carry-normalised) are the rounded real parts of a[0 .. n-1]. No trailing
// newline is written.
//   n : number of limbs to read from a; must be at least 1 and at most a.size().
void print(int n, const std::vector<Z> &a) {
    // Sized from n, not from a.size() or MAXN, so the work is proportional to
    // the actual number of limbs.
    std::vector<ll> ans(n);
    ll carry = 0;
    for (int i = 0; i < n; ++i) {
        const ll cur = std::llround(a[i].real()) + carry;
        ans[i] = cur % BASE;
        carry = cur / BASE;
    }
    while (carry > 0) {
        ans.push_back(carry % BASE);
        carry /= BASE;
    }

    // Skip leading zero limbs but always keep the lowest one.
    int top = static_cast<int>(ans.size()) - 1;
    while (top > 0 && ans[top] == 0) --top;

    // The most significant limb is printed bare, the rest zero-padded to B digits.
    std::printf("%d", static_cast<int>(ans[top]));
    for (int i = top - 1; i >= 0; --i)
        std::printf("%0*d", B, static_cast<int>(ans[i]));
}

// Splits a decimal digit string into base-10^B limbs, least significant first.
static std::vector<int> toLimbs(const char *digits) {
    const int len = static_cast<int>(std::strlen(digits));
    std::vector<int> limbs;
    limbs.reserve(len / B + 1);
    for (int hi = len - 1; hi >= 0; hi -= B) {
        int value = 0, scale = 1;
        for (int j = hi; j > hi - B && j >= 0; --j) {
            value += (digits[j] - '0') * scale;
            scale *= 10;
        }
        limbs.push_back(value);
    }
    return limbs;
}

// Input buffers; scanf("%s") is unbounded, so inputs are assumed to have
// fewer than MAXN digits.
char l1[MAXN], l2[MAXN];

// Reads pairs of non-negative decimal integers until input ends and prints
// the product of each pair on its own line.
int main() {
    while (std::scanf("%s %s", l1, l2) == 2) {
        const std::vector<int> g = toLimbs(l1);
        const std::vector<int> h = toLimbs(l2);
        const int n = static_cast<int>(g.size());
        const int m = static_cast<int>(h.size());

        int nn = 1;
        while (nn < std::max(n, m) * 2) nn <<= 1;

        std::vector<Z> yg(nn), yh(nn), yf(nn);  // value-initialised to 0 + 0i
        for (int i = 0; i < n; ++i) yg[i] = Z(g[i], 0);
        for (int i = 0; i < m; ++i) yh[i] = Z(h[i], 0);

        fft(nn, yg, false);
        fft(nn, yh, false);

        // Pointwise product in the frequency domain == convolution of limbs.
        for (int i = 0; i < nn; ++i) yf[i] = yg[i] * yh[i];
        fft(nn, yf, true);

        // The inverse transform is unnormalised; scale before printing.
        for (int i = 0; i < nn; ++i) yf[i] /= nn;

        print(nn, yf);
        std::puts("");
    }
    return 0;
}