「快速傅里叶变换 / 快速数论变换」学习笔记

接触 FFT,是第一次用 Python 写高精度乘法之后知道可以用它做到 O(nlogn)O(n \log n)
然鹅当时的我十分的 simple(并不意味着现在不是),对它望而却步。

FFT

何为 FFT

它可以在 O(nlogn)O(n \log n) 内把一个系数表示的多项式转化为它的点值表示

补充 - 点值表示

A(x)A(x) 为一个 n1n - 1 次多项式,那么用 nn 个不同的 xx 带入 AA,算出 nnyy
nn(x,y)(x,y) 可以唯一确定这个多项式

两个多项式相乘称为卷积
系数表示的多项式求卷积的复杂度是 O(n2)O(n^2) 的。
但点值表示的多项式的复杂度是 O(n)O(n) 的。

DFT 离散傅里叶变换

傅里叶教我们用特定的 xx 求点值表示——单位根!

补充 - 复数

从前老师教我们 n\sqrt n 有意义当且仅当 n0n \ge 0
但是我们也会遇到 1\sqrt{-1} 这种东西。
我们称其为虚数

虚数单位 i=1i = \sqrt{-1},一个复数 (x,y)=x+yi(x,y) = x + yi
其中的 xx 称为实部yy 称为虚部

把复数看成一个向量/点,它所在的平面直角坐标系有一个特殊的名称——复平面。

补充 - 单位根

把单位圆(圆心在原点,半径为 11 的圆)nn 等分,从 (1,0)(1,0) 开始逆时针将其编号,第 kk 个记为 ωnk\omega_n^k
显而易见 ωnk=(ωn1)k\omega_n^k = (\omega_n^1)^k,所以 ωn1\omega_n^1 称为 nn 次单位根。

ωnk=(coskn2π,sinkn2π)\omega_n^k = (\cos \dfrac k n 2 \pi,\sin \dfrac k n 2 \pi)
以及两个比较显然的性质:

  • ωnk=ωxnxk\omega_n^k = \omega_{xn}^{xk}
  • ωnk+n2=ωnk\omega_n^{k + \frac n 2} = -\omega_n^k

IDFT - 离散傅立叶逆变换

把多项式 A(x)A(x) 使用单位根的点值表示再次作为另一个多项式 B(x)B(x) 的系数表示,取 代入求得 BB 的点值表示。
将其每一位除以 nn,就得到了 AA 的系数表示。

A(x)A(x) 的点值表示是 B(x)B(x) 的点值表示是
上述结论的证明:

ik=0i - k = 0j=0n1(ωnik)j=n\sum\limits_{j = 0}^{n - 1} (\omega_n^{i - k})^j = n
其余时候根据等比数列求和公式,可知其值为 00

FFT 快速傅里叶变换

然鹅 DFT 仍然是 O(n2)O(n ^ 2) 的……
我们考虑用分治来优化。

A(x)=i=0n1aixiA(x) = \sum\limits_{i = 0}^{n - 1} a_i x^i
A0(x)=i=0n21a2ixi,A1(x)=i=0n21a2i+1xiA_0(x) = \sum\limits_{i = 0}^{\frac n 2 - 1} a_{2i} x^i,A_1(x) = \sum\limits_{i = 0}^{\frac n 2 - 1} a_{2i + 1} x^i
于是有 A(x)=A0(x2)+xA1(x2)A(x) = A_0(x^2) + x A_1(x^2)

对于 k<n2k < \frac n 2,有

然后就可以递归地写出一个 FFT 了。

一些优化

非递归

睿智的先人们找到了一种神奇的规律:在 FFT 分治时,最后第 xx 项所在的位置是 xx 二进制翻转后的数。

蝴蝶变换

证明十分严(kong)谨(bu),实际上在代码实现里只是把一个地方换了一下而简化了代码。

参考代码

洛谷 3803.多项式乘法(FFT)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <cmath>
#include <complex>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const double PI = acos(-1);
typedef complex<double> cp;
int lena,lenb,n = 1,lg;
cp a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void fft(cp *a,cp *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
cp t = omg[n / w * j] * a[i + j + m];
a[i + j + m] = a[i + j] - t,a[i + j] += t;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
inv[i] = conj(omg[i] = cp(cos(2 * PI * i / n),sin(2 * PI * i / n)));
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i].real(x);
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i].real(x);
fft(a,omg),fft(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
fft(a,inv);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%d ",(int)(a[i].real() / n + 0.5));
}


NTT

FFT 到 NTT

傅立叶把单位根的性质应用到了 FFT 中,但是是不是只有单位根有这样的性质呢?
——不,还有原根

补充 - 原根

对于 g,Pg,P,如果 ,则称 ggPP 的原根。

NTT 的特点

必须取模,而且模数形如 P=2kr+1P = 2^k r + 1

常用的模数有 998244353,1004535809998244353,1004535809,其原根均为 33
对于其他的模数,此处引用一个表,来源见参考文献
其中 ggP=2kr+1P = 2^k r + 1 的原根。

PP rr kk gg
3 1 1 2
5 1 2 2
17 1 4 3
97 3 5 5
193 3 6 5
257 1 8 3
7681 15 9 17
12289 3 12 11
40961 5 13 3
65537 1 16 3
786433 3 18 10
5767169 11 19 3
7340033 7 20 3
23068673 11 21 3
104857601 25 22 3
167772161 5 25 3
469762049 7 26 3
1004535809 479 21 3
2013265921 15 27 31
2281701377 17 27 3
3221225473 3 30 5
75161927681 35 31 3
77309411329 9 33 7
206158430209 3 36 22
2061584302081 15 37 7
2748779069441 5 39 3
6597069766657 3 41 5
39582418599937 9 42 5
79164837199873 9 43 5
263882790666241 15 44 7
1231453023109121 35 45 3
1337006139375617 19 46 3
3799912185593857 27 47 5
4222124650659841 15 48 19
7881299347898369 7 50 6
31525197391593473 7 52 3
180143985094819841 5 55 6
1945555039024054273 27 56 5
4179340454199820289 29 57 3

为什么用原根

NTT 中把所有的 ωnk\omega_n^k 全部替换成了 g(P1)kng^{\frac{(P - 1)k}n}
为什么可以呢?
因为 FFT 中用到的单位根的性质原根都满足。

参考代码

题目同上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1 << 21;
const long long mod = 998244353;
const long long G = 3;
const long long Gi = 332748118;
int lena,lenb,n = 1,lg;
long long fpow(long long a,long long b)
{
long long ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = ret * a % mod),a = a * a % mod;
return ret;
}
long long a[N + 5],b[N + 5],omg[N + 5],inv[N + 5];
void ntt(long long *a,long long *omg)
{
for(register int i = 0;i < n;++i)
{
int t = 0;
for(register int j = 0;j < lg;++j)
if(i & (1 << j))
t |= (1 << lg - j - 1);
if(i < t)
swap(a[i],a[t]);
}
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
long long t = omg[n / w * j] * a[i + j + m] % mod;
a[i + j + m] = (a[i + j] - t + mod) % mod,a[i + j] = (a[i + j] + t) % mod;
}
}
int main()
{
scanf("%d%d",&lena,&lenb);
++lena,++lenb;
for(;n < lena + lenb;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
omg[i] = fpow(G,(mod - 1) / n * i),inv[i] = fpow(Gi,(mod - 1) / n * i);
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),a[i] = x;
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),b[i] = x;
ntt(a,omg),ntt(b,omg);
for(register int i = 0;i < n;++i)
a[i] *= b[i];
ntt(a,inv);
long long n_inv = fpow(n,mod - 2);
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%lld ",a[i] * n_inv % mod);
}

参考文献

arknights