洛谷 3603 雪辉

我想这个大概可以算是某种树上分块吧(

看到这种题,看到数据范围,就要优先考虑 bitset(
那么也就是说需要求出 \(q\) 条链的 bitset。

链上问题优先考虑树剖,但是看起来树剖复杂度是假的(
那么,类比序列上的分块,就有了一个想法:
选取 \(O(\sqrt n)\) 个关键点,预处理出两两关键点之间的 bitset,然后每次从两边往 LCA 跳并利用预处理的 bitset 优化这一段的复杂度。

那么关键点怎么选呢?容易想到贪心地按深度从大到小考虑每个点,以其子树内每个点到最近的关键点的距离来决定这个点是否为关键点。
然而这样貌似很麻烦,我也不想写个 DS 题还得打个如此复杂的树形 DP(
所以,可以直接随机选取关键点,复杂度也是对的(

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <random>
#define lowbit(x) ((x) & -(x))
using namespace std;

const int BUFF_SIZE = 1 << 20;
char BUFF[BUFF_SIZE],*BB,*BE;
#define gc() (BB == BE ? (BE = (BB = BUFF) + fread(BUFF,1,BUFF_SIZE,stdin),BB == BE ? EOF : *BB++) : *BB++)
template<class T>
inline void read(T &x)
{
x = 0;
char ch = 0,w = 0;
for(;ch < '0' || ch > '9';w |= ch == '-',ch = gc());
for(;ch >= '0' && ch <= '9';x = (x << 3) + (x << 1) + (ch ^ '0'),ch = gc());
w && (x = -x);
}

const int N = 1e5;
const int C = 3e4;
const int CNT = 320;
mt19937 rnd(20070921);
int n,m,f;
int a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int vis[N + 5],key[CNT + 5],up[CNT + 5];
struct bitset
{
unsigned long long a[(C >> 6) + 5];
int len;
inline bitset()
{
memset(a,0,sizeof a),len = -1;
}
inline void clear()
{
memset(a,0,sizeof a),len = -1;
}
inline void operator|=(const bitset &o)
{
len = max(len,o.len);
for(register int i = 0;i <= len;++i)
a[i] |= o.a[i];
}
inline void set(int x)
{
len = max(len,x >> 6),a[x >> 6] |= 1ULL << (x & 63);
}
inline int cnt()
{
int ret = 0;
for(register int i = 0;i <= len;++i)
ret += __builtin_popcountll(a[i]);
return ret;
}
inline int mex()
{
for(register int i = 0;i <= len;++i)
if(a[i] ^ (~0ULL))
for(register int j = 0;j < 64;++j)
if(!(a[i] & (1ULL << j)))
return (i << 6) + j;
}
} ans,s[CNT + 5][CNT + 5];
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5];
int q[N + 5],head,tail;
void bfs()
{
dep[q[++tail] = 1] = top[1] = 1;
for(register int p;head < tail;)
{
sz[p = q[++head]] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
fa[to[i]] = p,dep[to[i]] = dep[p] + 1,q[++tail] = to[i];
}
for(register int i = n;i;--i)
{
sz[fa[q[i]]] += sz[q[i]];
if(!son[fa[q[i]]] || sz[q[i]] > sz[son[fa[q[i]]]])
son[fa[q[i]]] = q[i];
}
for(register int i = 1;i <= n;++i)
top[q[i]] = q[i] == son[fa[q[i]]] ? top[fa[q[i]]] : q[i];
}
int getlca(int x,int y)
{
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? (x = fa[top[x]]) : (y = fa[top[y]]);
return dep[x] < dep[y] ? x : y;
}
int lastans;
int main()
{
read(n),read(m),read(f);
for(register int i = 1;i <= n;++i)
read(a[i]);
int u,v;
for(register int i = 1;i < n;++i)
read(u),read(v),add(u,v),add(v,u);
bfs();
for(register int i = 1;i <= min(n,CNT);++i)
{
for(key[i] = rnd() % n + 1;vis[key[i]];key[i] = rnd() % n + 1);
vis[key[i]] = i;
}
for(register int i = 1,p;i <= min(n,CNT);++i)
{
s[i][i].set(a[p = key[i]]),ans.clear();
for(;p;p = fa[p])
{
ans.set(a[p]);
if(vis[p] && (p ^ key[i]))
s[i][vis[p]] = ans,!up[i] && (up[i] = p);
}
}
for(int c,x,y,ans1,ans2;m;--m)
{
read(c),ans.clear();
for(register int lca,l;c;--c)
{
read(x),read(y),f && (x ^= lastans,y ^= lastans);
ans.set(a[lca = getlca(x,y)]);

for(;!vis[x] && (x ^ lca);ans.set(a[x]),x = fa[x]);
if(x ^ lca)
{
for(l = x;dep[up[vis[l]]] > dep[lca];l = up[vis[l]]);
ans |= s[vis[x]][vis[l]];
for(l = fa[l];l ^ lca;ans.set(a[l]),l = fa[l]);
}

for(;!vis[y] && (y ^ lca);ans.set(a[y]),y = fa[y]);
if(y ^ lca)
{
for(l = y;dep[up[vis[l]]] > dep[lca];l = up[vis[l]]);
ans |= s[vis[y]][vis[l]];
for(l = fa[l];l ^ lca;ans.set(a[l]),l = fa[l]);
}
}
ans1 = ans.cnt(),ans2 = ans.mex(),lastans = ans1 + ans2;
printf("%d %d\n",ans1,ans2);
}
}