LibreOJ 3046 「ZJOI2019」语言

第一反应:树剖之后树套树维护……
估计是 \(O(n \log^4 n)\) 的(

然而并不是在线的,并且只有一次询问。
如果按上述做法需要把一条路径在两个维度都用树剖划分为几个区间,
但是其中一个维度可以换成树上差分,于是第二个维度变成线段树合并。

线段树使用类似扫描线的写法,避免重复统计。
注意要添加 \(n\) 个不存在的 \((i,i)\) 操作,是为了方便计算答案。
否则不容易直接计算满足 \(u = v\)\((u,v)\) 点对的个数。

此做法是 \(O(n \log^2 n)\) 的,我太弱了不会一个 \(\log\) 的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1e5;
int n,m;
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v;
pre[tot] = first[u];
first[u] = tot;
}
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5],id[N + 5];
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,dfs1(to[i]),sz[p] += sz[to[i]];
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
id[p] = ++tot;
if(son[p])
top[son[p]] = top[p],dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!top[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
struct node
{
int cnt,len;
int ls,rs;
} seg[(N << 8) + 10];
int rt[N + 5];
inline void up(int p,int tl,int tr)
{
if(seg[p].cnt > 0)
seg[p].len = tr - tl + 1;
else if(tl == tr)
seg[p].len = 0;
else
seg[p].len = seg[seg[p].ls].len + seg[seg[p].rs].len;
}
void update(int l,int r,int k,int &p,int tl,int tr)
{
static int tot = 0;
if(!p)
p = ++tot;
if(l <= tl && tr <= r)
{
seg[p].cnt += k,up(p,tl,tr);
return ;
}
int mid = tl + tr >> 1;
if(l <= mid)
update(l,r,k,seg[p].ls,tl,mid);
if(r > mid)
update(l,r,k,seg[p].rs,mid + 1,tr);
up(p,tl,tr);
}
int merge(int p,int q,int tl,int tr)
{
if(!p || !q)
return p | q;
seg[p].cnt += seg[q].cnt;
int mid = tl + tr >> 1;
seg[p].ls = merge(seg[p].ls,seg[q].ls,tl,mid);
seg[p].rs = merge(seg[p].rs,seg[q].rs,mid + 1,tr);
up(p,tl,tr);
return p;
}
inline int getlca(int x,int y)
{
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
return dep[x] < dep[y] ? x : y;
}
void update(int x,int y)
{
int u = x,v = y;
int lca = getlca(x,y);
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? (update(id[top[x]],id[x],1,rt[u],1,n),update(id[top[x]],id[x],1,rt[v],1,n),update(id[top[x]],id[x],-1,rt[lca],1,n),update(id[top[x]],id[x],-1,rt[fa[lca]],1,n),x = fa[top[x]]) : (update(id[top[y]],id[y],1,rt[u],1,n),update(id[top[y]],id[y],1,rt[v],1,n),update(id[top[y]],id[y],-1,rt[lca],1,n),update(id[top[y]],id[y],-1,rt[fa[lca]],1,n),y = fa[top[y]]);
if(dep[x] > dep[y])
swap(x,y);
update(id[x],id[y],1,rt[u],1,n),update(id[x],id[y],1,rt[v],1,n),update(id[x],id[y],-1,rt[lca],1,n),update(id[x],id[y],-1,rt[fa[lca]],1,n);
}
long long solve(int p)
{
long long ret = 0;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
ret += solve(to[i]),merge(rt[p],rt[to[i]],1,n);
ret += seg[rt[p]].len;
return ret;
}
int main()
{
scanf("%d%d",&n,&m);
int u,v;
for(register int i = 1;i < n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dep[1] = 1,top[1] = 1,dfs1(1),dfs2(1);
for(register int i = 1;i <= n;++i)
update(i,i);
while(m--)
scanf("%d%d",&u,&v),update(u,v);
printf("%lld\n",(solve(1) - n) / 2);
}