「任意模数 NTT 及其优化」学习笔记

myy txdy!

三模 NTT

这还讲啥呢……
都 0202 年了还有人写三模 NTT 吗……

拆系数 FFT

整一个阈值 WW,把系数拆成 aW+c (c<W)aW + c\ (c < W) 的形式。
一般设 W=215=32768W = 2^{15} = 32768

那么需要 7 次 DFT,起码比隔壁 9 次的三模好(虽然人家是 NTT)

三次变两次

若我们有俩整系数多项式 A,BA,B 要求它们的卷积,我们可以设

显然这俩是共轭的。

如果我们能求出 F,GF,G 的点值表达,也就能当成二元方程来解出 A,BA,B 的点值表达。
然而,实际上,只需要求 FF 的点值表达,就能推出 GG 的点值表达。


再令 即共轭负数。
再用小写字母表示多项式的系数序列。

注意当 k=0k=0

将该优化应用于拆系数 FFT

若有两个数 aW+baW + bcW+dcW + d 相乘,则结果应为 acW2+(ad+bc)W+bdacW^2 + (ad+bc)W + bd
于是 7 次 DFT 做完(就是上面那个朴素的拆系数 FFT)

但是,注意到

于是我们把 a+bi,abi,c+dia+bi,a-bi,c+di 分别 DFT 然后这样卷起来再 IDFT 回来,求出 ac,ad+bc,bdac,ad+bc,bd 就行了。
这样要 5 次。

然而 ,于是运用上面那个就可以少一次。
然后就变成 4 次了。
(听说还有 3.5 次然而我不会)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1 << 18;
const long double pi = acos(-1);
const int W = 1 << 15;
int lena,lenb,mod,n;
int lg2[N + 5],rev[N + 5];
int ans[N + 5];
struct cp
{
long double a,b;
inline void operator+=(const cp &o)
{
a += o.a,b += o.b;
}
inline cp operator+(const cp &o) const
{
return (cp){a + o.a,b + o.b};
}
inline cp operator-(const cp &o) const
{
return (cp){a - o.a,b - o.b};
}
inline cp operator*(const cp &o) const
{
return (cp){a * o.a - b * o.b,a * o.b + b * o.a};
}
inline cp operator*(const double &o) const
{
return (cp){a * o,b * o};
}
inline cp operator~() const
{
return (cp){a,-b};
}
} f[N + 5],g[N + 5],h[N + 5],a[N + 5],b[N + 5],rt[N + 5];
inline void init(int len)
{
for(n = 1;n < len;n <<= 1);
for(register int i = 2;i <= n;++i)
lg2[i] = lg2[i >> 1] + 1;
rt[n >> 1] = (cp){1,0};
for(register int i = 1;i <= (n >> 1);++i)
rt[(n >> 1) + i] = (cp){cos(2 * pi * i / n),sin(2 * pi * i / n)};
for(register int i = (n >> 1) - 1;i;--i)
rt[i] = rt[i << 1];
}
inline void fft(cp *a,int type,int n)
{
type == -1 && (reverse(a + 1,a + n),1);
int lg = lg2[n] - 1;
for(register int i = 0;i < n;++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg),
i < rev[i] && (swap(a[i],a[rev[i]]),1);
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
cp t = rt[m | j] * a[i | j | m];
a[i | j | m] = a[i | j] - t,a[i | j] += t;
}
if(type == -1)
for(register int i = 0;i < n;++i)
a[i].a /= n,a[i].b /= n;
}
int main()
{
scanf("%d%d%d",&lena,&lenb,&mod),init(max(++lena,++lenb) << 1);
int x;
for(register int i = 0;i < lena;++i)
scanf("%d",&x),f[i] = (cp){x / W,x % W};
for(register int i = 0;i < lenb;++i)
scanf("%d",&x),g[i] = (cp){x / W,x % W};
fft(f,1,n),fft(g,1,n);
for(register int i = 0;i < n;++i)
h[i] = ~f[(n - i) % n];
for(register int i = 0;i < n;++i)
a[i] = f[i] * g[i],b[i] = g[i] * h[i];
fft(a,-1,n),fft(b,-1,n);
for(register int i = 0;i < n;++i)
{
long long ac = (a[i].a + b[i].a) / 2 + 0.5;
long long bd = b[i].a - ac + 0.5;
long long bcad = a[i].b + 0.5;
ans[i] = ((ac % mod * W % mod * W % mod) % mod + (bcad % mod * W % mod) % mod + bd % mod) % mod;
}
for(register int i = 0;i < lena + lenb - 1;++i)
printf("%d ",ans[i]);
}