BZOJ 2152 聪聪可可

据说可以树形 DP \(O(n)\) 过去?
反正我喜欢点分治!

由于题目限制的仅是 \(3\) 的余数(当然开到 \(10^5\) 点分治也可做),所以我们可以对于距离除以 \(3\) 的余数 \(0,1,2\) 来统计。
即,设 \(cnt_{0/1/2}\) 表示距离膜 \(3\)\(0/1/2\) 的点的个数。
那么对于结点 \(p\),把答案加上 \(cnt_x\;(dis_p + x \equiv 0 \pmod 3,0 \le x < 3)\)
显然 \(x\) 就是 \(-dis_p \bmod 3\)

然后有一些细节,比如到当前分治的点的路径只会被统计一次,
所以我们干脆只统计 \(x < y\)\((x,y)\) 的个数,并在输出时乘 \(2\)\(n\)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 2e4;
int n;
int to[(N << 1) + 5],pre[(N << 1) + 5],val[(N << 1) + 5],first[N + 5];
inline void add(int u,int v,int w)
{
static int tot = 0;
to[++tot] = v;
val[tot] = w;
pre[tot] = first[u];
first[u] = tot;
}
int vis[N + 5],sum,cnt[N + 5];
long long dis[N + 5];
int sz[N + 5],max_part[N + 5],rt;
long long ans;
void get_rt(int p,int fa)
{
sz[p] = 1,max_part[p] = 0;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa && !vis[to[i]])
{
get_rt(to[i],p);
sz[p] += sz[to[i]];
max_part[p] = max(max_part[p],sz[to[i]]);
}
max_part[p] = max(max_part[p],sum - sz[p]);
if(max_part[p] < max_part[rt])
rt = p;
}
void get_dis(int p,int fa)
{
++cnt[dis[p] % 3];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa && !vis[to[i]])
{
dis[to[i]] = dis[p] + val[i];
get_dis(to[i],p);
}
}
void clear(int p,int fa)
{
--cnt[dis[p] % 3];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa && !vis[to[i]])
clear(to[i],p);
}
void calc(int p,int fa)
{
ans += cnt[(3 - dis[p] % 3) % 3];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa && !vis[to[i]])
calc(to[i],p);
}
void solve(int p)
{
vis[p] = 1;
cnt[0] = cnt[1] = cnt[2] = 0;
dis[p] = 0;
get_dis(p,0);
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]])
clear(to[i],p),calc(to[i],p);
for(register int i = first[p];i;i = pre[i])
if(!vis[to[i]])
{
rt = 0,sum = sz[to[i]],get_rt(to[i],p);
solve(rt);
}
}
int main()
{
max_part[0] = 0x3f3f3f3f;
scanf("%d",&n);
int u,v,w;
for(register int i = 1;i < n;++i)
scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
rt = 0,sum = n,get_rt(1,0);
solve(rt);
ans += ans + n;
printf("%lld/%lld\n",ans / __gcd(ans,(long long)n * n),(long long)n * n / __gcd(ans,(long long)n * n));
}