LibreOJ 2058 「TJOI / HEOI2016」求和

记 \(f_n = \sum\limits_{i=0}^n 2^i \cdot i! {n \brace i}\),则题目所求的答案为 \(\sum\limits_{k=0}^{n} f_k\)。
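考虑其组合意义:\(i! {n \brace i}\) 是把 \(n\) 个不同的物品放进 \(i\) 个有标号的**非空**盒子中的方案数,再乘上 \(2^i\) 相当于给每个盒子染成黑、红两种颜色之一。因此 \(f_n\) 表示把 \(n\) 个不同的物品放进若干个(\(0\) 到 \(n\) 个)有标号的非空盒子中,且每个盒子染成黑、红两种颜色之一的方案数。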

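于是枚举 1 号盒子(第一个盒子)中物品的个数 \(i\):选出这 \(i\) 个物品有 \(\binom n i\) 种方法,该盒子的颜色有 \(2\) 种,剩下的 \(n-i\) 个物品构成规模更小的同类问题。易得递推式 \(f_n = 2 \sum\limits_{i=1}^n \binom n i f_{n-i}\)(\(n \ge 1\)),边界为 \(f_0 = 1\)。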
推一推 \[ \begin{align*} f_n &= 2 \sum\limits_{i=1}^n \binom n i f_{n-i} \\ f_n &= 2 \sum\limits_{i=1}^n \frac{n!}{i!(n-i)!} f_{n-i} \\ \frac{f_n}{2 \cdot n!} &= \sum\limits_{i=1}^n \frac 1{i!} \cdot \frac{f_{n-i}}{(n-i)!} \end{align*} \]

右边是熟悉的卷积形式(其实也不是很熟悉)。不过 \(f\) 同时出现在等式两边,每一项都依赖于前面的项,所以不能直接做一次卷积,而要用分治(CDQ)的方式在线计算。
于是结合分治 NTT 就有了一个垃圾的 \(O(n \log^2 n)\) 解法……
似乎还有一个 \(\log\) 的解法和线性的解法……

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
// Capacity of every global table (and of the NTT scratch buffers).
const int N = 1 << 18;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root modulo `mod` and its modular inverse (3 * 332748118 == mod + 1).
const int G = 3;
const int Gi = 332748118;
// len: the bound read from input; n: padded power-of-two length, grown in main().
int len;
int n = 1;
// Square-and-multiply modular exponentiation: returns a^b modulo `mod`.
// Every call site in this file passes 0 <= a < mod and b >= 0.
int fpow(int a,int b)
{
    long long result = 1;
    long long base = a;
    while (b != 0)
    {
        if (b & 1)
            result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// f[i]: while solve() runs, the accumulated convolution sum for index i;
// after solve(0, n - 1) returns, the final value f_i.
int f[N + 5];
// Twiddle tables refilled by mul(): forward roots (omg) and inverse roots (inv).
int omg[N + 5];
int inv[N + 5];
// fac[i] = i! and ifac[i] = (i!)^{-1} modulo `mod`, filled in main().
int fac[N + 5];
int ifac[N + 5];
// Sum of f[0..len], printed by main().
int ans;
// In-place iterative radix-2 number-theoretic transform modulo `mod`.
//   a   : n residues, transformed in place.
//   omg : root table, omg[k] = w^k for the n-th root of unity w of this
//         direction. mul() passes the forward table for the transform and the
//         inverse-root table for the inverse transform; the division by n for
//         the inverse is done by the caller.
//   n   : transform length, expected to be a power of two equal to 1 << lg.
//   lg  : number of low index bits reversed by the permutation step.
void ntt(int *a,int *omg,int n,int lg)
{
    // Bit-reversal permutation: reverse the low lg bits of every index.
    for (int idx = 0; idx < n; ++idx)
    {
        int rev = 0;
        int rest = idx;
        for (int bit = 0; bit < lg; ++bit)
        {
            rev = (rev << 1) | (rest & 1);
            rest >>= 1;
        }
        if (idx < rev)
            swap(a[idx], a[rev]);
    }
    // Butterfly passes: merge pairs of length-`half` blocks into length-`width` blocks.
    for (int half = 1; (half << 1) <= n; half <<= 1)
    {
        const int width = half << 1;
        const int stride = n / width; // omg[stride * k] is the k-th power of a width-th root
        for (int start = 0; start < n; start += width)
        {
            for (int k = 0; k < half; ++k)
            {
                int &lo = a[start + k];
                int &hi = a[start + k + half];
                const int t = static_cast<long long>(omg[stride * k]) * hi % mod;
                hi = (lo - t + mod) % mod;
                lo = (lo + t) % mod;
            }
        }
    }
}
void mul(int *a,int *b,int n)
{
int lg = 0;
for(register int i = 1;i < n;i <<= 1,++lg);
for(register int i = 0;i < n;++i)
omg[i] = fpow(G,(mod - 1) / n * i),inv[i] = fpow(Gi,(mod - 1) / n * i);
ntt(a,omg,n,lg),ntt(b,omg,n,lg);
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * b[i] % mod;
ntt(a,inv,n,lg);
int n_inv = fpow(n,mod - 2);
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * n_inv % mod;
}
void solve(int l,int r)
{
if(l == r)
{
f[l] = 2LL * f[l] * fac[l] % mod;
return ;
}
int mid = l + r >> 1;
solve(l,mid);
static int buf[2][N + 5];
for(register int i = 0;i <= mid - l;++i)
buf[0][i] = (long long)f[i + l] * ifac[i + l] % mod;
for(register int i = mid - l + 1;i <= r - l;++i)
buf[0][i] = 0;
for(register int i = 0;i <= r - l;++i)
buf[1][i] = (long long)ifac[i] % mod;
mul(buf[0],buf[1],r - l + 1);
for(register int i = mid + 1;i <= r;++i)
f[i] = (f[i] + buf[0][i - l]) % mod;
solve(mid + 1,r);
}
// Entry point: reads the bound len and prints sum_{k=0}^{len} f_k modulo `mod`.
// NOTE(review): assumes len < N so every table index stays in range — confirm
// against the problem's limits.
int main()
{
    if (scanf("%d",&len) != 1)
        return 1;
    // solve() works on the power-of-two range [0, n - 1], which must contain
    // index len itself, so n has to be strictly greater than len. Stopping at
    // n >= len would leave f[len] uncomputed whenever len is a power of two.
    for (; n <= len; n <<= 1);
    // Factorials and inverse factorials for indices [0, n).
    fac[0] = 1;
    for (int i = 1; i < n; ++i)
        fac[i] = static_cast<long long>(fac[i - 1]) * i % mod;
    ifac[n - 1] = fpow(fac[n - 1],mod - 2);
    for (int i = n - 1; i > 0; --i)
        ifac[i - 1] = static_cast<long long>(ifac[i]) * i % mod;
    // Seed f[0] with 1/2: solve()'s leaf step multiplies by 2 * 0! = 2, giving f_0 = 1.
    f[0] = (mod + 1) / 2;
    solve(0, n - 1);
    for (int i = 0; i <= len; ++i)
        ans = (ans + f[i]) % mod;
    printf("%d\n",ans);
}