JZOJ 5662 尺树寸泓

文艺出题人啊……

第一眼看起来是个动态树,但是出题人给出了一个更简单的做法,仅需要线段树就可以解决。

首先我们从平衡树的定义开始分析:
众所周知平衡树是一种特殊的 BST,而用平衡树维护序列时,它的中序遍历就是这个序列。
以及,平衡树的旋转是维护平衡的一个策略,显然它的中序遍历不会改变。

于是又发现一个性质:中序遍历时,一棵子树是连续的。

所以我们用前缀和维护中序遍历的区间和,用线段树维护中序遍历的区间积。
旋转只会改变两棵子树的区间,可以在常数复杂度内维护。

Orz 了 LZW 大佬的代码发现可以直接用结点的中序遍历减去左儿子的 size 得到区间左端点,右端点类似。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <cstdio>
#include <algorithm>
#define ls(p) tree[p].lson
#define rs(p) tree[p].rson
using namespace std;
const int N = 2e5;
const long long mod = 1e9 + 7;
int n,m;
int st[N + 5],ed[N + 5],id[N + 5];
long long a[N + 5],sum[N + 5];
long long seg[(N << 2) + 5];
void change(int x,long long k,int p,int tl,int tr)
{
if(tl == tr)
{
seg[p] = k % mod;
return ;
}
int mid = tl + tr >> 1;
if(x <= mid)
change(x,k,p << 1,tl,mid);
else
change(x,k,p << 1 | 1,mid + 1,tr);
seg[p] = seg[p << 1] * seg[p << 1 | 1] % mod;
}
long long query(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p];
int mid = tl + tr >> 1;
long long ret = 1;
if(l <= mid)
ret = ret * query(l,r,p << 1,tl,mid) % mod;
if(r > mid)
ret = ret * query(l,r,p << 1 | 1,mid + 1,tr) % mod;
return ret;
}
struct node
{
int lson,rson,fa;
} tree[N + 5];
void dfs(int p,int fa)
{
static int tot = 0;
if(!p)
return ;
tree[p].fa = fa;
dfs(ls(p),p);
sum[id[p] = ++tot] = a[p];
dfs(rs(p),p);
if(!ls(p))
st[p] = id[p];
else
st[p] = st[ls(p)];
if(!rs(p))
ed[p] = id[p];
else
ed[p] = ed[rs(p)];
}
int main()
{
freopen("splay.in","r",stdin);
freopen("splay.out","w",stdout);
scanf("%d%d",&n,&m);
for(register int i = 1;i <= n;++i)
scanf("%lld%d%d",a + i,&ls(i),&rs(i));
dfs(1,0);
for(register int i = 1;i <= n;++i)
sum[i] += sum[i - 1];
for(register int i = 1;i <= n;++i)
change(id[i],sum[ed[i]] - sum[st[i] - 1],1,1,n);
int op,x;
while(m--)
{
scanf("%d%d",&op,&x);
if(op == 0)
{
if(!ls(x))
continue ;
int t = ls(x);
ls(x) = rs(t);
rs(t) = x;
if(x == ls(tree[x].fa))
ls(tree[x].fa) = t;
else if(x == rs(tree[x].fa))
rs(tree[x].fa) = t;
tree[t].fa = tree[x].fa;
tree[x].fa = t;
tree[ls(x)].fa = x;
if(ls(x))
st[x] = st[ls(x)];
else
st[x] = id[x];
if(rs(t))
ed[t] = ed[rs(t)];
else
ed[t] = id[t];
change(id[x],sum[ed[x]] - sum[st[x] - 1],1,1,n);
change(id[t],sum[ed[t]] - sum[st[t] - 1],1,1,n);
}
else if(op == 1)
{
if(!rs(x))
continue;
int t = rs(x);
rs(x) = ls(t);
ls(t) = x;
if(x == ls(tree[x].fa))
ls(tree[x].fa) = t;
else if(x == rs(tree[x].fa))
rs(tree[x].fa) = t;
tree[t].fa = tree[x].fa;
tree[x].fa = t;
tree[rs(x)].fa = x;
if(rs(x))
ed[x] = ed[rs(x)];
else
ed[x] = id[x];
if(ls(t))
st[t] = st[ls(t)];
else
st[t] = id[t];
change(id[x],sum[ed[x]] - sum[st[x] - 1],1,1,n);
change(id[t],sum[ed[t]] - sum[st[t] - 1],1,1,n);
}
else
printf("%lld\n",query(st[x],ed[x],1,1,n));
}
}