「快速傅里叶变换 / 快速数论变换」学习笔记

接触 FFT,是第一次用 Python 写高精度乘法之后知道可以用它做到 \(O(n \log n)\)
然鹅当时的我十分的 simple(并不意味着现在不是),对它望而却步。

FFT

何为 FFT

它可以在 \(O(n \log n)\) 内把一个系数表示的多项式转化为它的点值表示

补充 - 点值表示

\(A(x)\) 为一个 \(n - 1\) 次多项式,那么用 \(n\) 个不同的 \(x\) 带入 \(A\),算出 \(n\)\(y\)
\(n\)\((x,y)\) 可以唯一确定这个多项式

两个多项式相乘称为卷积
系数表示的多项式求卷积的复杂度是 \(O(n^2)\) 的。
但点值表示的多项式的复杂度是 \(O(n)\) 的。

DFT 离散傅里叶变换

傅里叶教我们用特定的 \(x\) 求点值表示——单位根!

补充 - 复数

从前老师教我们 \(\sqrt n\) 有意义当且仅当 \(n \ge 0\)
但是我们也会遇到 \(\sqrt{-1}\) 这种东西。
我们称其为虚数

虚数单位 \(i = \sqrt{-1}\),一个复数 \((x,y) = x + yi\)
其中的 \(x\) 称为实部\(y\) 称为虚部

把复数看成一个向量/点,它所在的平面直角坐标系有一个特殊的名称——复平面。

补充 - 单位根

把单位圆(圆心在原点,半径为 \(1\) 的圆)\(n\) 等分,从 \((1,0)\) 开始逆时针将其编号,第 \(k\) 个记为 \(\omega_n^k\)
显而易见 \(\omega_n^k = (\omega_n^1)^k\),所以 \(\omega_n^1\) 称为 \(n\) 次单位根。

\(\omega_n^k = (\cos \dfrac k n 2 \pi,\sin \dfrac k n 2 \pi)\)
以及两个比较显然的性质: - \(\omega_n^k = \omega_{xn}^{xk}\)。 - \(\omega_n^{k + \frac n 2} = -\omega_n^k\)

IDFT - 离散傅立叶逆变换

把多项式 \(A(x)\) 使用单位根的点值表示再次作为另一个多项式 \(B(x)\) 的系数表示,取 \(\omega_n^0,\omega_n^{-1},\dots,\omega_n^{-n + 1}\) 代入求得 \(B\) 的点值表示。
将其每一位除以 \(n\),就得到了 \(A\) 的系数表示。

\(A(x)\) 的点值表示是 \((b_1,b_2,\dots,b_n)\)\(B(x)\) 的点值表示是 \((c_1,c_2,\dots,c_n)\)
上述结论的证明: \[\begin{align*} c_k & = \sum\limits_{i = 0}^{n - 1} b_i (\omega_n^{-k})^i & = \sum\limits_{i = 0}^{n - 1} (\sum\limits_{j = 0}^{n - 1} ) (\omega_n^{-k})^i & = \sum\limits_{i = 0}^{n - 1} \sum\limits_{j = 0}^{n - 1} (\omega_n^{i - k})^j a_i \end{align*}\]

\(i - k = 0\)\(\sum\limits_{j = 0}^{n - 1} (\omega_n^{i - k})^j = n\)
其余时候根据等比数列求和公式,可知其值为 \(0\)

FFT 快速傅里叶变换

然鹅 DFT 仍然是 \(O(n ^ 2)\) 的……
我们考虑用分治来优化。

\(A(x) = \sum\limits_{i = 0}^{n - 1} a_i x^i\)
\(A_0(x) = \sum\limits_{i = 0}^{\frac n 2 - 1} a_{2i} x^i,A_1(x) = \sum\limits_{i = 0}^{\frac n 2 - 1} a_{2i + 1} x^i\)
于是有 \(A(x) = A_0(x^2) + x A_1(x^2)\)

对于 \(k < \frac n 2\),有 \[\begin{align*} A(\omega_n^k) & = A_0((\omega_n^k)^2) + \omega_n^k A_1((\omega_n^k)^2) \\ & = A_0(\omega_{\frac n 2}^k) + \omega_n^k A_1(\omega_{\frac n 2}^k) \end{align*}\] \[\begin{align*} A(\omega_n^{k + \frac n 2}) & = A_0((\omega_n^{k + \frac n 2})^2) + \omega_n^k A_1((\omega_n^{k + \frac n 2})^2) \\ & = A_0(\omega_{\frac n 2}^k) - \omega_n^k A_1(\omega_{\frac n 2}^k) \end{align*}\]

然后就可以递归地写出一个 FFT 了。

一些优化

非递归

睿智的先人们找到了一种神奇的规律:在 FFT 分治时,最后第 \(x\) 项所在的位置是 \(x\) 二进制翻转后的数。

蝴蝶变换

证明十分严(kong)谨(bu),实际上在代码实现里只是把一个地方换了一下而简化了代码。

参考代码

洛谷 3803.多项式乘法(FFT)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <cmath>
#include <complex>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const double PI = acos(-1);
typedef complex<double> cp;
int lena,lenb,n = 1,lg;
cp a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void fft(cp *a,cp *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
cp t = omg[n / w * j] * a[i + j + m];
a[i + j + m] = a[i + j] - t,a[i + j] += t;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
inv[i] = conj(omg[i] = cp(cos(2 * PI * i / n),sin(2 * PI * i / n)));
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i].real(x);
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i].real(x);
fft(a,omg),fft(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
fft(a,inv);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%d ",(int)(a[i].real() / n + 0.5));
}


NTT

FFT 到 NTT

傅立叶把单位根的性质应用到了 FFT 中,但是是不是只有单位根有这样的性质呢?
——不,还有原根

补充 - 原根

对于 \(g,P\),如果 \(\forall 1 \le i,j < P,i \ne j,g_i \not\equiv g_j \pmod P\),则称 \(g\)\(P\) 的原根。

NTT 的特点

必须取模,而且模数形如 \(P = 2^k r + 1\)

常用的模数有 \(998244353,1004535809\),其原根均为 \(3\)
对于其他的模数,此处引用一个表,来源见参考文献
其中 \(g\)\(P = 2^k r + 1\) 的原根。

\(P\) \(r\) \(k\) \(g\)
3 1 1 2
5 1 2 2
17 1 4 3
97 3 5 5
193 3 6 5
257 1 8 3
7681 15 9 17
12289 3 12 11
40961 5 13 3
65537 1 16 3
786433 3 18 10
5767169 11 19 3
7340033 7 20 3
23068673 11 21 3
104857601 25 22 3
167772161 5 25 3
469762049 7 26 3
1004535809 479 21 3
2013265921 15 27 31
2281701377 17 27 3
3221225473 3 30 5
75161927681 35 31 3
77309411329 9 33 7
206158430209 3 36 22
2061584302081 15 37 7
2748779069441 5 39 3
6597069766657 3 41 5
39582418599937 9 42 5
79164837199873 9 43 5
263882790666241 15 44 7
1231453023109121 35 45 3
1337006139375617 19 46 3
3799912185593857 27 47 5
4222124650659841 15 48 19
7881299347898369 7 50 6
31525197391593473 7 52 3
180143985094819841 5 55 6
1945555039024054273 27 56 5
4179340454199820289 29 57 3

为什么用原根

NTT 中把所有的 \(\omega_n^k\) 全部替换成了 \(g^{\frac{(P - 1)k}n}\)
为什么可以呢?
因为 FFT 中用到的单位根的性质原根都满足。

参考代码

题目同上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const long long mod = 998244353;
const long long G = 3;
const long long Gi = 332748118;
int lena,lenb,n = 1,lg;
long long fpow(long long a,long long b)
{
long long ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = ret * a % mod),a = a * a % mod;
return ret;
}
long long a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void ntt(long long *a,long long *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
long long t = omg[n / w * j] * a[i + j + m] % mod;
a[i + j + m] = (a[i + j] - t + mod) % mod,a[i + j] = (a[i + j] + t) % mod;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
omg[i] = fpow(G,(mod - 1) / n * i),inv[i] = fpow(Gi,(mod - 1) / n * i);
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i] = x;
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i] = x;
ntt(a,omg),ntt(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
ntt(a,inv);
long long n_inv = fpow(n,mod - 2);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%lld ",a[i] * n_inv % mod);
}

参考文献