LibreOJ 2537.「PKUWC2018」Minimax

首先离散化。
考虑一个简单 \(O(n^2)\) 树形 DP:设 \(f_{u,i}\) 表示 \(u\) 取到 \(i\) 的概率,\(l,r\) 分别为 \(u\) 的左 / 右儿子。
那么有 \[ f_{u,i} = f_{l,i} \left[p_u \sum\limits_{j=1}^{i-1} f_{r,j} + (1-p_u) \sum\limits_{j=i+1}^m f_{r,j}\right] + f_{r,i} \left[p_u \sum\limits_{j=1}^{i-1} f_{l,j} + (1-p_u) \sum\limits_{j=i+1}^m f_{l,j}\right] \] 前后缀和优化即可。

事实上,这种形式的转移可以使用线段树合并来优化(即所谓「整体 DP」)。
由于没有相同值,合并到最后一定只剩一边非空。
在线段树合并的过程中不难算出两棵线段树上目前的前后缀和,直接打上乘法标记即可。

另外,\(10^{-4} \equiv 796898467 \pmod{998244353}\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 3e5;
const int mod = 998244353;
const int inv1e4 = 796898467;
int n;
int fa[N + 5],ch[N + 5][2];
int a[N + 5],ind[N + 5],len;
int d[N + 5],ans;
namespace SEG
{
struct node
{
int sum,tag;
int ls,rs;
} seg[(N << 5) + 5];
inline void push(int p)
{
if(seg[p].tag ^ 1)
{
if(seg[p].ls)
seg[seg[p].ls].sum = (long long)seg[seg[p].ls].sum * seg[p].tag % mod,
seg[seg[p].ls].tag = (long long)seg[seg[p].ls].tag * seg[p].tag % mod;
if(seg[p].rs)
seg[seg[p].rs].sum = (long long)seg[seg[p].rs].sum * seg[p].tag % mod,
seg[seg[p].rs].tag = (long long)seg[seg[p].rs].tag * seg[p].tag % mod;
seg[p].tag = 1;
}
}
void insert(int x,int &p,int tl,int tr)
{
static int tot = 0;
!p && (seg[p = ++tot].tag = 1);
if(tl == tr)
{
seg[p].sum = 1;
return ;
}
push(p);
int mid = tl + tr >> 1;
x <= mid ? insert(x,seg[p].ls,tl,mid) : insert(x,seg[p].rs,mid + 1,tr);
seg[p].sum = (seg[seg[p].ls].sum + seg[seg[p].rs].sum) % mod;
}
int merge(int p,int q,int ptag,int qtag,int v)
{
if(!p && !q)
return 0;
if(!p)
{
seg[q].sum = (long long)seg[q].sum * qtag % mod,
seg[q].tag = (long long)seg[q].tag * qtag % mod;
return q;
}
if(!q)
{
seg[p].sum = (long long)seg[p].sum * ptag % mod,
seg[p].tag = (long long)seg[p].tag * ptag % mod;
return p;
}
push(p),push(q);
int plsum = seg[seg[p].ls].sum,prsum = seg[seg[p].rs].sum;
int qlsum = seg[seg[q].ls].sum,qrsum = seg[seg[q].rs].sum;
seg[p].ls = merge(seg[p].ls,seg[q].ls,(ptag + (long long)(1 - v + mod) * qrsum) % mod,(qtag + (long long)(1 - v + mod) * prsum) % mod,v),
seg[p].rs = merge(seg[p].rs,seg[q].rs,(ptag + (long long)v * qlsum) % mod,(qtag + (long long)v * plsum) % mod,v),
seg[p].sum = (seg[seg[p].ls].sum + seg[seg[p].rs].sum) % mod;
return p;
}
void dfs(int p,int tl,int tr)
{
if(tl == tr)
{
d[tl] = seg[p].sum;
return ;
}
push(p);
int mid = tl + tr >> 1;
seg[p].ls && (dfs(seg[p].ls,tl,mid),1);
seg[p].rs && (dfs(seg[p].rs,mid + 1,tr),1);
}
}
int rt[N + 5];
void dfs(int p)
{
if(ch[p][1])
dfs(ch[p][0]),dfs(ch[p][1]),
rt[p] = SEG::merge(rt[ch[p][0]],rt[ch[p][1]],0,0,a[p]);
else if(ch[p][0])
dfs(ch[p][0]),
rt[p] = rt[ch[p][0]];
else
SEG::insert(a[p],rt[p],1,len);
}
int main()
{
scanf("%d%*d",&n);
for(register int i = 2;i <= n;++i)
scanf("%d",fa + i),ch[fa[i]][(bool)ch[fa[i]][0]] = i;
for(register int i = 1;i <= n;++i)
scanf("%d",a + i),ch[i][0] ? (a[i] = (long long)a[i] * inv1e4 % mod) : (ind[++len] = a[i]);
sort(ind + 1,ind + len + 1),len = unique(ind + 1,ind + len + 1) - ind - 1;
for(register int i = 1;i <= n;++i)
!ch[i][0] && (a[i] = lower_bound(ind + 1,ind + len + 1,a[i]) - ind);
dfs(1),SEG::dfs(rt[1],1,len);
for(register int i = 1;i <= len;++i)
ans = (ans + (long long)i * ind[i] % mod * d[i] % mod * d[i]) % mod;
printf("%d\n",ans);
}