洛谷 5439.「X Round 2」永恒

蒟蒻不会边分治……
诶一看这个时限 O(nlog2n)O(n \log^2 n) 很有戏啊……
所以?树剖?

首先我们记永恒的树为 TT',Trie 为 TT,以及树上的一些相关值带 ' 表示永恒的树意义的,否则表示在 Trie 上的。
题目求的是

u,vT,u<v[x,y][u,v]depLCA(x,y)\sum\limits_{u,v \in T',u < v} \sum\limits_{[x,y] \subseteq [u,v]} dep_{\text{LCA}(x,y)}

(此处 depdep00 开始计算)

然后有一个简单的思想就是把后面的拉到前面来,即对于 x,yx,y 计算 depLCA(x,y)dep_{\text{LCA(x,y)}} 并求其出现次数(作为子路径的次数)。
这个就是套路了。

考虑当两者没有祖孙关系时,那么贡献显然就是 sizexsizeysize'_x \cdot size'_y

否则,假设 xxyy 的祖先,zzxyx \to y 路径上除 xx 以外离 xx 最近的点。
此时贡献是 (nsizez)sizey(n - size'_z) size'_y

第一种贡献其实同样比较套路,看到 LCA\text{LCA}depdep 就想到了「『LNOI2014』LCA」。
所以类似地,枚举 xx,并让其到根路径都加上 sizexsize'_x
枚举 yy,对根路径求和,乘上 sizeysize'_y
注意要除掉根的贡献,因为 depdep00 开始记。
那么离线一下可以做到 O(n)O(n)

有一个潜在的问题:第二种贡献的点在第一种贡献里也被计算了。
这个先不讨论。

第二种贡献,我们跑一遍 DFS。因为只要确定 xx 了,yy 一定在 xx 的子树当中。
假设当前遍历到的点为 xx,枚举其儿子为 zz,那么对于 zz 子树之内的 yy 计算时都是同一个 zz
所以我们就给 xx 到根加上 nsizezn - size'_z,然后每次遍历到点就计算贡献。
注意回溯时要减回去。

然后再来讨论刚才那个问题,我们发现,只需要在 DFS 的过程中把 sizexsize'_x 对于子树内的贡献除掉即可。
就是说把加的数换成 nsizezsizexn - size'_z - size'_x

最近复习了下区间查改的树状数组,常数又小,换上了。
反正有取模不会爆。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <cstdio>
#include <algorithm>
#include <vector>
#define lowbit(x) ((x) & -(x))
using namespace std;
const int N = 3e5;
const int M = 3e5;
const long long mod = 998244353;
const long long inv = 499122177;
int n,m;
int d[N + 5];
vector<int> e[N + 5];
int siz[N + 5];
int to[M + 5],pre[M + 5],first[M + 5];
long long ans;
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int fa[M + 5],sz[M + 5],dep[M + 5],son[M + 5],top[M + 5],id[M + 5];
long long sum[M + 5],w[M + 5];
void dfs(int p,int fa)
{
siz[p] = 1;
for(register int i = 0;i < e[p].size();++i)
if(e[p][i] ^ fa)
dfs(e[p][i],p),siz[p] += siz[e[p][i]];
}
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,dfs1(to[i]),sz[p] += sz[to[i]],w[p] = (w[p] + w[to[i]]) % mod;
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
id[p] = ++tot;
if(son[p])
top[son[p]] = top[p],dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!id[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
void change(int p,long long k)
{
while(p)
sum[id[top[p]]] = (sum[id[top[p]]] + k) % mod,sum[id[p] + 1] = ((sum[id[p] + 1] - k) % mod + mod) % mod,p = fa[top[p]];
}
void get(int p,long long k)
{
while(p)
ans = (ans + ((sum[id[p]] - sum[id[top[p]] - 1]) % mod + mod) % mod * k % mod) % mod,p = fa[top[p]];
}
long long c[2][N + 5];
inline void add(int op,int x,long long k)
{
for(;x <= m;x += lowbit(x))
c[op][x] = (c[op][x] + k) % mod;
}
inline long long ask(int op,int x)
{
long long ret = 0;
for(;x;x -= lowbit(x))
ret = (ret + c[op][x]) % mod;
return ret;
}
inline void update(int l,int r,long long k)
{
add(0,l,k % mod),add(0,r + 1,(mod - k % mod) % mod),add(1,l,l * k % mod),add(1,r + 1,(mod - r - 1) * k % mod);
}
inline long long query(int l,int r)
{
return (((r + 1) * ask(0,r) % mod - ask(1,r) + mod) % mod - (l * ask(0,l - 1) % mod - ask(1,l - 1) + mod) % mod + mod) % mod;
}
void modify(int p,long long k)
{
while(p)
update(id[top[p]],id[p],k),p = fa[top[p]];
}
void answer(int p,long long k)
{
while(p)
ans = (ans + query(id[top[p]],id[p]) * k % mod) % mod,p = fa[top[p]];
}
void calc(int p,int fa)
{
answer(d[p],siz[p]);
for(register int i = 0;i < e[p].size();++i)
if(e[p][i] ^ fa)
modify(d[p],(n - siz[e[p][i]] - siz[p] + mod) % mod),calc(e[p][i],p),modify(d[p],(siz[e[p][i]] + siz[p] - n + mod) % mod);
}
int main()
{
scanf("%d%d",&n,&m);
int u,p;
for(register int i = 1;i <= n;++i)
scanf("%d",&u),u ? (e[u].push_back(i),0) : (p = i);
dfs(p,0);
for(register int i = 1;i <= m;++i)
scanf("%d",fa + i),fa[i] && (add(fa[i],i),1);
scanf("%*s");
for(register int i = 1;i <= n;++i)
scanf("%d",d + i);
for(register int i = 1;i <= n;++i)
w[d[i]] = (w[d[i]] + siz[i]) % mod;
for(register int i = first[1];i;i = pre[i])
fa[to[i]] = 0,dep[to[i]] = 1,top[to[i]] = to[i],dfs1(to[i]),dfs2(to[i]);
for(register int i = 1;i <= m;++i)
sum[id[i]] = w[i];
for(register int i = 1;i <= m;++i)
sum[i] = (sum[i] + sum[i - 1]) % mod;
for(register int i = 1;i <= n;++i)
get(d[i],siz[i]),ans = ((ans - (long long)siz[i] * siz[i] % mod * dep[d[i]] % mod) % mod + mod) % mod;
ans = ans * inv % mod;
calc(p,0);
printf("%lld\n",ans);
}