「任意模数 NTT 及其优化」学习笔记

myy txdy!

三模 NTT

这还讲啥呢……
都 0202 年了还有人写三模 NTT 吗……

拆系数 FFT

整一个阈值 \(W\),把系数拆成 \(aW + c\ (c < W)\) 的形式。
一般设 \(W = 2^{15} = 32768\)

那么需要 7 次 DFT,起码比隔壁 9 次的三模好(虽然人家是 NTT)

三次变两次

若我们有俩整系数多项式 \(A,B\) 要求它们的卷积,我们可以设 \[ \newcommand{\dft}{\mathrm{DFT}} \newcommand{\conj}{\operatorname{conj}} F(x) = A(x) + iB(x),G(x) = A(x) - iB(x) \] 显然这俩是共轭的。

如果我们能求出 \(F,G\) 的点值表达,也就能当成二元方程来解出 \(A,B\) 的点值表达。
然而,实际上,只需要求 \(F\) 的点值表达,就能推出 \(G\) 的点值表达。

\(F_{\dft}(k) = F(\omega_n^k),G_{\dft}(k) = G(\omega_n^k)\)
再令 \(\conj(a+bi) = a-bi\) 即共轭负数。
再用小写字母表示多项式的系数序列。

\[ \begin{align*} F_{\dft}(k) &= A(\omega_n^k) + iB(\omega_n^k) \\ &= \sum\limits_{j=0}^{n-1} (a_j + ib_j)\omega_n^{jk} \\ G_{\dft}(k) &= \sum\limits_{j=0}^{n-1} (a_j - ib_j)\omega_n^{jk} \\ &= \sum\limits_{j=0}^{n-1} (a_j - ib_j)\left(\cos \frac{2\pi jk}n + i\sin \frac{2\pi jk}n\right) \\ &= \sum\limits_{j=0}^{n-1} \left(\left(a_j\cos\frac{2\pi jk}n + b_j\sin\frac{2\pi jk}n\right) + i\left(a_j\sin\frac{2\pi jk}n - b_j\cos\sin\frac{2\pi jk}n\right)\right) \\ &= \conj\left(\sum\limits_{j=0}^{n-1} \left(\left(a_j\cos\frac{2\pi jk}n + b_j\sin\frac{2\pi jk}n\right) - i\left(a_j\sin\frac{2\pi jk}n - b_j\cos\frac{2\pi jk}n\right)\right)\right) \\ &= \conj\left(\sum\limits_{j=0}^{n-1} \left(\left(a_j\cos\frac{-2\pi jk}n - b_j\sin\frac{-2\pi jk}n\right) + i\left(a_j\sin\frac{-2\pi jk}n + b_j\cos\frac{-2\pi jk}n\right)\right)\right) \\ &= \conj\left(\sum\limits_{j=0}^{n-1} (a_j + ib_j)\left(\cos\frac{-2\pi jk}n + i\sin\frac{-2\pi jk}n\right)\right) \\ &= \conj\left(\sum\limits_{j=0}^{n-1} (a_j + ib_j)\omega_n^{-jk}\right) \\ &= \conj\left(\sum\limits_{j=0}^{n-1} (a_j + ib_j)\omega_n^{(n-k)j}\right) \\ &= \conj(F_{\dft}(n - k)) \end{align*} \]

注意当 \(k=0\)\(n - k \equiv 0 \pmod n\)

将该优化应用于拆系数 FFT

若有两个数 \(aW + b\)\(cW + d\) 相乘,则结果应为 \(acW^2 + (ad+bc)W + bd\)
于是 7 次 DFT 做完(就是上面那个朴素的拆系数 FFT)

但是,注意到 \[ (a+bi)(c+di)=(ac-bd)+(bc+ad)i \\ (a-bi)(c+di)=(ac+bd)+(ad-bc)i \]

于是我们把 \(a+bi,a-bi,c+di\) 分别 DFT 然后这样卷起来再 IDFT 回来,求出 \(ac,ad+bc,bd\) 就行了。
这样要 5 次。

然而 \(\conj(a+bi)=a-bi\),于是运用上面那个就可以少一次。
然后就变成 4 次了。
(听说还有 3.5 次然而我不会)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1 << 18;
const long double pi = acos(-1);
const int W = 1 << 15;
int lena,lenb,mod,n;
int lg2[N + 5],rev[N + 5];
int ans[N + 5];
struct cp
{
long double a,b;
inline void operator+=(const cp &o)
{
a += o.a,b += o.b;
}
inline cp operator+(const cp &o) const
{
return (cp){a + o.a,b + o.b};
}
inline cp operator-(const cp &o) const
{
return (cp){a - o.a,b - o.b};
}
inline cp operator*(const cp &o) const
{
return (cp){a * o.a - b * o.b,a * o.b + b * o.a};
}
inline cp operator*(const double &o) const
{
return (cp){a * o,b * o};
}
inline cp operator~() const
{
return (cp){a,-b};
}
} f[N + 5],g[N + 5],h[N + 5],a[N + 5],b[N + 5],rt[N + 5];
inline void init(int len)
{
for(n = 1;n < len;n <<= 1);
for(register int i = 2;i <= n;++i)
lg2[i] = lg2[i >> 1] + 1;
rt[n >> 1] = (cp){1,0};
for(register int i = 1;i <= (n >> 1);++i)
rt[(n >> 1) + i] = (cp){cos(2 * pi * i / n),sin(2 * pi * i / n)};
for(register int i = (n >> 1) - 1;i;--i)
rt[i] = rt[i << 1];
}
inline void fft(cp *a,int type,int n)
{
type == -1 && (reverse(a + 1,a + n),1);
int lg = lg2[n] - 1;
for(register int i = 0;i < n;++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg),
i < rev[i] && (swap(a[i],a[rev[i]]),1);
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
cp t = rt[m | j] * a[i | j | m];
a[i | j | m] = a[i | j] - t,a[i | j] += t;
}
if(type == -1)
for(register int i = 0;i < n;++i)
a[i].a /= n,a[i].b /= n;
}
int main()
{
scanf("%d%d%d",&lena,&lenb,&mod),init(max(++lena,++lenb) << 1);
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),f[i] = (cp){x / W,x % W};
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),g[i] = (cp){x / W,x % W};
fft(f,1,n),fft(g,1,n);
for(register int i = 0;i < n;++i)
h[i] = ~f[(n - i) % n];
for(register int i = 0;i < n;++i)
a[i] = f[i] * g[i],b[i] = g[i] * h[i];
fft(a,-1,n),fft(b,-1,n);
for(register int i = 0;i < n;++i)
{
long long ac = (a[i].a + b[i].a) / 2 + 0.5;
long long bd = b[i].a - ac + 0.5;
long long bcad = a[i].b + 0.5;
ans[i] = ((ac % mod * W % mod * W % mod) % mod + (bcad % mod * W % mod) % mod + bd % mod) % mod;
}
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%d ",ans[i]);
}