LibreOJ 6073 「2017 山东一轮集训 Day5」距离

这题主要是用到了一个非常妙的 trick——
求一个点集中所有点与一个另外钦定的点的距离和。

首先可以把两点距离转化为到根的距离相减,即 \(dis_x + dis_y - 2dis_{lca}\)
那么在整个式子里,这个钦定的点的 \(dis\) 是固定的,另一个 \(dis\) 也比较好求。
发现 \(dis_{lca}\) 的部分就是两条到根路径的交。
所以可以对于这个点集里所有点都对其到根路径每条边覆盖次数加一,然后从钦定点到根计算贡献,一条边的贡献是覆盖次数乘边权。

那么这个题就好办了,我们把询问也换成到根相减的形式,那么问题变成 \[\sum\limits_{i \in \text{path}(root,u)} \text{dist}(p_i,k)\] 这就是上面那个 trick 了。

可以在一开始用主席树对于每个点 \(u\) 维护所有 \(p_i\) 到根路径覆盖次数加一的结果,其中 \(i \in \text{path}(root,u)\)
然后询问就在对应的主席树上计算到根的贡献即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <cstdio>
#include <utility>
using namespace std;
const int N = 2e5;
int type,n,q;
int P[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],val[(N << 1) + 5],first[N + 5];
inline void add(int u,int v,int w)
{
static int tot = 0;
to[++tot] = v,val[tot] = w,pre[tot] = first[u],first[u] = tot;
}
int a[N + 5];
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5],id[N + 5],rk[N + 5];
long long dis[N + 5],sum[N + 5],lastans;
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,a[to[i]] = val[i],dis[to[i]] = dis[p] + val[i],dfs1(to[i]),sz[p] += sz[to[i]];
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
rk[id[p] = ++tot] = p;
if(son[p])
top[son[p]] = top[p],dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!id[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
struct node
{
long long val,sum,tag;
int ls,rs;
} seg[(N << 7) + 10];
int rt[N + 5];
int seg_tot;
void build(int &p,int tl,int tr)
{
p = ++seg_tot;
if(tl == tr)
{
seg[p].val = a[rk[tl]];
return ;
}
int mid = tl + tr >> 1;
build(seg[p].ls,tl,mid),build(seg[p].rs,mid + 1,tr);
seg[p].val = seg[seg[p].ls].val + seg[seg[p].rs].val;
}
void update(int l,int r,int &p,int tl,int tr)
{
seg[++seg_tot] = seg[p],p = seg_tot;
if(l <= tl && tr <= r)
{
seg[p].sum += seg[p].val,++seg[p].tag;
return ;
}
int mid = tl + tr >> 1;
if(l <= mid)
update(l,r,seg[p].ls,tl,mid);
if(r > mid)
update(l,r,seg[p].rs,mid + 1,tr);
seg[p].sum = seg[seg[p].ls].sum + seg[seg[p].rs].sum + seg[p].tag * seg[p].val;
}
pair<long long,long long> operator+(const pair<long long,long long> &a,const pair<long long,long long> &b)
{
return make_pair(a.first + b.first,a.second + b.second);
}
pair<long long,long long> query(int l,int r,int p,int tl,int tr)
{
if(!p || (l <= tl && tr <= r))
return make_pair(seg[p].sum,seg[p].val);
int mid = tl + tr >> 1;
pair<long long,long long> ret(0,0);
if(l <= mid)
ret = ret + query(l,r,seg[p].ls,tl,mid);
if(r > mid)
ret = ret + query(l,r,seg[p].rs,mid + 1,tr);
ret.first += ret.second * seg[p].tag;
return ret;
}
void dfs3(int p)
{
int x = P[p];
rt[p] = rt[fa[p]],sum[p] = sum[fa[p]] + dis[P[p]];
while(x)
update(id[top[x]],id[x],rt[p],1,n),x = fa[top[x]];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
dfs3(to[i]);
}
long long query(int p,int x)
{
long long ret = sum[p] + dep[p] * dis[x];
while(x)
ret -= query(id[top[x]],id[x],rt[p],1,n).first * 2,x = fa[top[x]];
return ret;
}
int getlca(int x,int y)
{
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
return dep[x] < dep[y] ? x : y;
}
int main()
{
scanf("%d%d%d",&type,&n,&q);
int u,v,w;
for(register int i = 1;i < n;++i)
scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
for(register int i = 1;i <= n;++i)
scanf("%d",P + i);
dep[1] = 1,top[1] = 1,dfs1(1),dfs2(1),build(rt[0],1,n),dfs3(1);
long long x,y,k;
while(q--)
{
scanf("%lld%lld%lld",&x,&y,&k),type ? (x ^= lastans,y ^= lastans,k ^= lastans) : 0;
int lca = getlca(x,y);
printf("%lld\n",lastans = query(x,k) + query(y,k) - query(lca,k) - query(fa[lca],k));
}
}