LibreOJ 3045.「ZJOI2019」开关

\(S = \sum\limits_{i=1}^n p_i\)

由于不同开关的操作之间有顺序,即有标号,考虑对每个开关的操作构造 EGF 并乘起来。
\[ F_i(x) = \sum\limits_{n=0}^{\infty} [n \equiv s_i \pmod 2] \frac{p_i^n}S \cdot \frac{x^n}{n!} \]\[ F(x) = \prod\limits_{i=1}^n F_i(x) \] \(F(x)\) 即序列 \(\{f_n\}\) 的 EGF,其中 \(f_n = n! [x^n]F(x)\) 表示 \(n\) 次达到指定状态的概率。

但是容易发现题目要求第一次达到指定状态的期望次数,考虑再构造一些东西。

\(g_n\) 表示 \(n\) 次关闭全部开关的概率,\(h_n\) 表示 \(n\) 次达到期望状态且是首次达到的概率。
\(f(x),g(x),h(x)\) 分别为 \(\{f_n\},\{g_n\},\{h_n\}\) 的 OGF,则容易发现 \(f(x) = h(x) \cdot g(x)\)\(h(x) = \frac{f(x)}{g(x)}\)

再根据一些基本知识,易知所谓期望步数即 \(\sum\limits_{n=0}^{\infty} n \cdot h_n = \sum\limits_{n=0}^{\infty} [x^n]h'(x) = h'(1)\)

考虑如何求答案。
首先易知 \[ \begin{aligned} F(x) &= \prod\limits_{i=1}^n \frac{\exp\left(\frac{p_i}S x\right) + (-1)^{s_i}\exp\left(-\frac{p_i}S x\right)}2 \\ G(x) &= \prod\limits_{i=1}^n \frac{\exp\left(\frac{p_i}S x\right) + \exp\left(-\frac{p_i}S x\right)}2 \end{aligned} \]

考虑把 \(F(x),G(x)\) 看做关于 \(\exp\left(\frac 1S x\right)\) 的多项式。
\[ \begin{aligned} F(x) &= \sum\limits_{i=-S}^S \mathcal F_i \exp\left(\frac iS x\right) \\ G(x) &= \sum\limits_{i=-S}^S \mathcal G_i \exp\left(\frac iS x\right) \end{aligned} \]

系数可以通过背包 DP 求出。

易得 \[ \begin{aligned} f(x) &= \sum\limits_{i=-S}^S \frac{\mathcal F_i}{1 - \frac iS x} \\ g(x) &= \sum\limits_{i=-S}^S \frac{\mathcal G_i}{1 - \frac iS x} \end{aligned} \]

那么根据基本知识 \(h'(x) = \frac{f'(x)g(x) + f(x)g'(x)}{g^2(x)}\),考虑求 \(f(1),f'(1),g(1),g'(1)\)
但是可惜的是它们并不收敛……

考虑把 \(f,g\) 乘上 \((1-x)\),再推一推,可得答案为 \[ \frac 1{\mathcal G_i^2} \sum\limits_{i=-S}^{S-1} \frac{(\mathcal F_i \mathcal G_S - \mathcal F_S \mathcal G_i)S}{i - S} \]

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <cstdio>
using namespace std;
const int N = 100;
const int S = 5e4;
const int mod = 998244353;
const int inv = 499122177;
int n,s[N + 5],p[N + 5],sum,isum,ans;
int fr[N + 5][(S << 1) + 5],gr[N + 5][(S << 1) + 5];
int *f[N + 5],*g[N + 5];
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
int main()
{
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%d",s + i),s[i] = s[i] ? mod - 1 : 1;
for(register int i = 1;i <= n;++i)
scanf("%d",p + i),sum += p[i];
for(register int i = 0;i <= n;++i)
f[i] = fr[i] + S,g[i] = gr[i] + S;
f[0][0] = g[0][0] = 1;
for(register int i = 1;i <= n;++i)
for(register int j = -sum;j <= sum;++j)
j + p[i] <= sum && (f[i][j + p[i]] = (f[i][j + p[i]] + (long long)f[i - 1][j] * inv) % mod,g[i][j + p[i]] = (g[i][j + p[i]] + (long long)g[i - 1][j] * inv) % mod),
j - p[i] >= -sum && (f[i][j - p[i]] = (f[i][j - p[i]] + (long long)s[i] * f[i - 1][j] % mod * inv) % mod,g[i][j - p[i]] = (g[i][j - p[i]] + (long long)g[i - 1][j] * inv) % mod);
isum = fpow(sum,mod - 2);
for(register int i = -sum;i < sum;++i)
ans = (ans + ((long long)f[n][i] * g[n][sum] % mod - (long long)f[n][sum] * g[n][i] % mod + mod) * sum % mod * fpow((i - sum + mod) % mod,mod - 2)) % mod;
ans = (long long)ans * fpow(g[n][sum],mod - 3) % mod;
printf("%d\n",ans);
}