JZOJ 5666 法力风暴

设 \(a_i = A_i\),即第 \(i\) 个数的初始值。
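注意到 \(\prod\limits_{i\ne x} a_i = \prod\limits_{i=1}^n a_i - \prod\limits_{i=1}^n (a_i - [i=x])\),即每次操作的贡献等于操作前后所有数之积的差。对 \(k\) 次操作求和后裂项相消,总贡献为 \(\prod\limits_{i=1}^n a_i - \prod\limits_{i=1}^n a'_i\),其中 \(a'_i\) 表示 \(a_i\) 所有操作后的值。由期望的线性性,问题转化为求 \(E(\prod\limits_{i=1}^n a'_i)\)。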

设 \(b_i = a_i - a'_i\),即第 \(i\) 个数在 \(k\) 次操作中被选中的次数,则 \((b_1,\dots,b_n)\) 服从多项分布,有 \[ \begin{align*} E(\prod\limits_{i=1}^n(a_i - b_i)) &=\frac1{n^k}\sum\limits_{b_1+b_2+\dots+b_n=k}\frac{k!}{\prod\limits_{i=1}^n b_i!}\prod\limits_{i=1}^n(a_i - b_i) \\ &=\frac{k!}{n^k}\sum\limits_{b_1+b_2+\dots+b_n=k}\prod\limits_{i=1}^n\frac{a_i-b_i}{b_i!} \\ \end{align*} \]

如果有点生成函数知识的话,容易发现这很像某些生成函数之积的 \(x^k\) 项系数。
设指数生成函数 \(G_i(x)=\sum\limits_{j=0}^{\infty}\frac{a_i-j}{j!}x^j\)(\(b_i\) 可以为 \(0\),故求和从 \(j=0\) 开始)
容易发现 \(G_i(x)=e^x(a_i-x)\)
\(\sum\limits_{b_1+b_2+\dots+b_n=k}\prod\limits_{i=1}^n\frac{a_i-b_i}{b_i!}=(\prod\limits_{i=1}^nG_i(x))[x^k]\)

然后发现 \(\prod\limits_{i=1}^n G_i(x)=e^{nx}\prod\limits_{i=1}^n(a_i-x)\)
考虑用普通的分治思想 + NTT 来求 \(\prod\limits_{i=1}^n(a_i-x)\)(并不是 CDQ 分治)。
(大常数选手被卡)

又因为 \(e^{nx}=\sum\limits_{i=0}^{\infty}\frac{n^ix^i}{i!}\),记 \(f_i = [x^i]\prod\limits_{j=1}^n(a_j-x)\)(该多项式次数为 \(n\)),得 \((\prod\limits_{i=1}^nG_i(x))[x^k] = \sum\limits_{i=0}^{\min(n,k)} \frac{n^{k-i}}{(k-i)!}f_i\)
由于 \(k\) 很大,\((k-i)!\) 无法直接计算,看起来是不能做的;但别忘了前面还有系数 \(\frac{k!}{n^k}\):
\[ \begin{align*} \frac{k!}{n^k}\sum\limits_{i=0}^n \frac{n^{k-i}}{(k-i)!}f_i &=\sum\limits_{i=0}^n \frac{k!}{(k-i)!\,n^i}f_i \end{align*} \]

注意到 \(\frac{k!}{(k-i)!}=k(k-1)\cdots(k-i+1)\) 只是 \(i\) 项之积,即下降幂 \(k^{\underline{i}}\),当 \(i>k\) 时为 \(0\)。因此可以随 \(i\) 递推,直接计算。
实际上 \(\prod\limits_{i=1}^n a_i = f_0\),恰好等于上式中 \(i=0\) 的项,两者相消,故所求答案为 \(-\sum\limits_{i=1}^n \frac{k!}{(k-i)!\,n^i}f_i\)

代码:

#pragma GCC optimize("Ofast")
#pragma GCC target("sse3","sse2","sse")
#pragma GCC diagnostic error "-std=c++14"
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"
#pragma GCC optimize("fast-math","unroll-loops","no-stack-protector","inline")
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
// Capacity of every fixed-size array. solve() uses f[0 .. 2n-1] and the root
// NTT in mul() has length 2n, so the padded size n must not exceed N / 2.
constexpr int N = 1 << 18;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int mod = 998244353;
// Generator used to build the NTT roots, and its inverse modulo mod
// (3 * 332748118 = 998244354, which is 1 mod 998244353).
constexpr int G = 3;
constexpr int Gi = 332748118;
// len: number of input values.
int len;
// n: len rounded up to a power of two (grown from 1 in main).
int n = 1;
// k: number of operations (reduced modulo mod in main).
int k;
/// Modular exponentiation by repeated squaring.
/// @param a base
/// @param b exponent (b == 0 yields 1)
/// @return a^b reduced modulo `mod`
int fpow(int a,int b)
{
    long long base = a;
    long long acc = 1;
    while (b)
    {
        if (b & 1)
            acc = acc * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(acc);
}
// Input values a_0 .. a_{len-1}.
int a[N + 5];
// Working storage for solve(); after solve(0, n - 1, 0, 2 * n - 1), f[i] is the
// coefficient of x^i in prod_j (a_j - x), reduced modulo mod.
int f[N + 5];
// Forward / inverse NTT twiddle tables; rebuilt by every call to mul().
int omg[N + 5];
int inv[N + 5];
// Final answer, accumulated in main().
int ans;
/// In-place iterative radix-2 number-theoretic transform modulo `mod`.
/// @param a   array of n coefficients; every entry must already lie in [0, mod).
/// @param omg twiddle table with omg[i] = w^i for a primitive n-th root w.
///            Passing the inverse-root table gives the inverse transform; the
///            caller performs the final 1/n scaling.
/// @param n   transform length, a power of two.
/// @param lg  log2(n). The implementation no longer needs it; it is kept so that
///            existing callers compile unchanged.
void ntt(int *a,int *omg,int n,int lg)
{
    (void)lg;
    // Bit-reversal permutation in amortized O(n): j always equals i with its
    // bits reversed, advanced by a "reversed increment" (carry runs top-down).
    for (int i = 1, j = 0; i < n; ++i)
    {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }
    for (int w = 2, m = 1; w <= n; w <<= 1, m <<= 1)
    {
        // omg[step * j] is the j-th power of a primitive w-th root.
        const int step = n / w;
        for (int i = 0; i < n; i += w)
            for (int j = 0; j < m; ++j)
            {
                const int u = a[i + j];
                const int t = (long long)omg[step * j] * a[i + j + m] % mod;
                // u, t < mod and 2 * mod < INT_MAX, so a conditional
                // add/subtract replaces the more expensive `%`.
                a[i + j] = u + t >= mod ? u + t - mod : u + t;
                a[i + j + m] = u - t < 0 ? u - t + mod : u - t;
            }
    }
}
void mul(int *a,int *b,int len)
{
int lg = 0,n = 1;
for(;n < len * 2;n <<= 1,++lg);
for(register int i = 0;i < n;++i)
omg[i] = fpow(G,(mod - 1) / n * i),inv[i] = fpow(Gi,(mod - 1) / n * i);
ntt(a,omg,n,lg),ntt(b,omg,n,lg);
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * b[i] % mod;
ntt(a,inv,n,lg);
int n_inv = fpow(n,mod - 2);
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * n_inv % mod;
}
void solve(int l,int r,int L,int R)
{
if(l == r)
{
if(l < len)
f[L] = a[l] % mod,f[L + 1] = mod - 1;
else
f[L] = 1,f[L + 1] = 0;
return ;
}
int mid = l + r >> 1,MID = L + R >> 1;
solve(l,mid,L,MID);
solve(mid + 1,r,MID + 1,R);
static int buf[2][N + 5];
for(register int i = 0;i <= R - L;++i)
buf[0][i] = buf[1][i] = 0;
for(register int i = 0;i <= mid - l + 1;++i)
buf[0][i] = f[L + i],buf[1][i] = f[MID + 1 + i];
mul(buf[0],buf[1],mid - l + 2);
for(register int i = 0;i <= r - l + 1;++i)
f[L + i] = buf[0][i];
}
int main()
{
freopen("manastorm.in","r",stdin),freopen("manastorm.out","w",stdout);
scanf("%d%d",&len,&k),k %= mod;
for(;n < len;n <<= 1);
for(register int i = 0;i < len;++i)
scanf("%d",a + i);
solve(0,n - 1,0,2 * n - 1);
int n_inv = fpow(len,mod - 2);
for(register int i = 1,prod = (long long)n_inv * k % mod;i <= len;prod = (long long)prod * n_inv % mod * (k - i) % mod,++i)
ans = (ans - (long long)f[i] * prod % mod + mod) % mod;
printf("%d\n",ans);
}