LibreOJ 3312.「ZJOI2020」传统艺能

老线段树了……

考虑期望的线性性,拆成每个结点的概率。
\(f,g\) 分别为一个结点被打标记,其祖先被打标记的概率。
对于结点 \([l,r]\),设其父亲为 \([L,R]\),考虑以下五种不同的修改区间:

  1. 不进入其父亲。
  2. 进入其父亲,并在该结点上打了标记。
  3. 在其除了自身以外的祖先上打了标记。
  4. 将父亲的标记下推。
  5. 进入了该结点。

令其概率分别为 \(p_1,p_2,p_3,p_4,p_5\),则显然转移为 \[ (g,f) \to ((p_1+p_4)g + p_2 + p_3,(p_1+p_3)f + p_4g + p_2) \]

定义变换 \[ (g,f) \xrightarrow{(a,b,c,d,e)} (ag+b,cg+df+e) \]

考虑两个变换的复合,做一些 dirty works 易知其对原形式封闭,即具有结合律,具体地, \[ \begin{aligned} &\quad\,\, (a_0,b_0,c_0,d_0,e_0) \circ (a_1,b_1,c_1,d_1,e_1) \\ &= (a_0a_1,b_0a_1+b_1,a_0c_1+c_0d_1,d_0d_1,b_0c_1+e_0d_1+e_1) \end{aligned} \]

然后可以做变换的快速幂。
时间复杂度 \(O(n \log k)\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2e5;
const int mod = 998244353;
int n,k,inv;
int ans;
struct trans
{
int a,b,c,d,e;
inline trans(int x1 = 1,int x2 = 0,int x3 = 0,int x4 = 1,int x5 = 0)
{
a = x1,b = x2,c = x3,d = x4,e = x5;
}
inline trans operator*(const trans &o) const
{
return trans((long long)a * o.a % mod,((long long)b * o.a + o.b) % mod,((long long)a * o.c + (long long)c * o.d) % mod,(long long)d * o.d % mod,((long long)b * o.c + (long long)e * o.d + o.e) % mod);
}
};
inline int fpow(int a,int b)
{
int ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
return ret;
}
inline trans fpow(trans a,int b)
{
trans ret;
for(;b;b >>= 1)
(b & 1) && (ret = ret * a,1),a = a * a;
return ret;
}
inline int C(int n)
{
return ((long long)n * (n + 1) >> 1) % mod;
}
void build(int l,int r,int L,int R)
{
if(l == 1 && r == n)
ans = (ans + inv) % mod;
else
{
int p1 = (long long)(C(L - 1) + C(n - R)) * inv % mod,
p2 = ((long long)(R - r) * l + (long long)(l - L) * (n - r + 1)) % mod * inv % mod,
p3 = (long long)L * (n - R + 1) % mod * inv % mod,
p4 = ((long long)(L + l - 1) * (l - L) + (long long)(n - R + 1 + n - r) * (R - r) >> 1) % mod * inv % mod;
ans = (ans + fpow(trans((p1 + p4) % mod,(p2 + p3) % mod,p4,(p1 + p3) % mod,p2),k).e) % mod;
}
if(l == r)
return ;
int mid;
scanf("%d",&mid);
build(l,mid,l,r),build(mid + 1,r,l,r);
}
int main()
{
scanf("%d%d",&n,&k),inv = fpow(C(n),mod - 2);
build(1,n,0,0);
printf("%d\n",ans);
}