LibreOJ 3043.「ZJOI2019」线段树

ZJOI 神仙题!四年三道线段树!
据说毒瘤开题顺序 2 3 1,而且只有数据结构题可做……
身在 GD 的 ZJ 人已经被老家省选虐怕了……

容易发现线段树的棵数永远都是 2i2^i,并且每一棵线段树的结构是一模一样的。
并且复制操作实际上只需要考虑原来的贡献 + 新的贡献即可。

fi,pf_{i,p} 表示所有 2i2^i 棵线段树中的 pp 结点有多少个有标记。
考虑把线段树上所有结点按照一次修改操作的遍历顺序分成五类(没错就是在 Orz Sooke 的题解):

  1. 修改涉及了一部分区间的结点,即修改操作走过且接着往儿子中走的结点。
  2. 修改涉及了全部区间的结点,且本次被打了标记。
  3. 修改未涉及的结点,但在父亲结点被走过时被下推了标记。
  4. 修改涉及了全部区间的结点,但本次未被走到过。
  5. 其他修改未涉及的结点。

结点 pp 属于第 cc 类记作

然后先列出几个转移:

写到 的时候发现了点严肃的问题:
转移值与祖先中是否有标记有关。

来,直接再设一个 gi,pg_{i,p} 表示所有 2i2^i 棵线段树中的 pp 结点有多少个所有祖先都没有标记。

然后考虑怎么实现。
都有 O(logn)O(\log n) 个,于是直接暴力模拟修改操作即可。
看起来不可做,其实就是几个区间双倍的操作,打标记处理。
然后维护一下 ff 的总和,滚动 ii 维。
注意分清楚题目中的线段树和维护 DP 值的线段树……下推标记多多益善……

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <cstdio>
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const int mod = 998244353;
int n,m,pw[N + 5];
struct node
{
int f,g;
int tf,tg;
int sum;
} seg[(N << 3) + 10];
inline void push(int p)
{
if(seg[p].tf ^ 1)
{
seg[ls].f = (long long)seg[ls].f * seg[p].tf % mod,seg[ls].tf = (long long)seg[ls].tf * seg[p].tf % mod;
seg[ls].sum = (long long)seg[ls].sum * seg[p].tf % mod;
seg[rs].f = (long long)seg[rs].f * seg[p].tf % mod,seg[rs].tf = (long long)seg[rs].tf * seg[p].tf % mod;
seg[rs].sum = (long long)seg[rs].sum * seg[p].tf % mod;
seg[p].tf = 1;
}
if(seg[p].tg ^ 1)
{
seg[ls].g = (long long)seg[ls].g * seg[p].tg % mod,seg[ls].tg = (long long)seg[ls].tg * seg[p].tg % mod;
seg[rs].g = (long long)seg[rs].g * seg[p].tg % mod,seg[rs].tg = (long long)seg[rs].tg * seg[p].tg % mod;
seg[p].tg = 1;
}
}
void build(int p,int tl,int tr)
{
seg[p].tf = seg[p].tg = seg[p].g = 1;
if(tl == tr)
return ;
int mid = tl + tr >> 1;
build(ls,tl,mid),build(rs,mid + 1,tr);
}
void update(int l,int r,int k,int p,int tl,int tr)
{
push(p);
if(l <= tl && tr <= r)
{
seg[p].f = (seg[p].f + pw[k - 1]) % mod;
seg[ls].f = 2LL * seg[ls].f % mod,seg[ls].tf = 2LL * seg[ls].tf % mod;
seg[ls].sum = 2LL * seg[ls].sum % mod;
seg[rs].f = 2LL * seg[rs].f % mod,seg[rs].tf = 2LL * seg[rs].tf % mod;
seg[rs].sum = 2LL * seg[rs].sum % mod;
seg[p].sum = (seg[p].f + (seg[ls].sum + seg[rs].sum) % mod) % mod;
return ;
}
seg[p].g = (seg[p].g + pw[k - 1]) % mod;
int mid = tl + tr >> 1;
if(l <= mid && r > mid)
update(l,r,k,ls,tl,mid),update(l,r,k,rs,mid + 1,tr);
else if(l <= mid)
{
seg[rs].f = (seg[rs].f + (pw[k - 1] - seg[rs].g + mod) % mod) % mod,seg[rs].g = 2LL * seg[rs].g % mod;
push(rs);
seg[rs << 1].f = 2LL * seg[rs << 1].f % mod,seg[rs << 1].tf = 2LL * seg[rs << 1].tf % mod;
seg[rs << 1].sum = 2LL * seg[rs << 1].sum % mod;
seg[rs << 1].g = 2LL * seg[rs << 1].g % mod,seg[rs << 1].tg = 2LL * seg[rs << 1].tg % mod;
seg[rs << 1 | 1].f = 2LL * seg[rs << 1 | 1].f % mod,seg[rs << 1 | 1].tf = 2LL * seg[rs << 1 | 1].tf % mod;
seg[rs << 1 | 1].sum = 2LL * seg[rs << 1 | 1].sum % mod;
seg[rs << 1 | 1].g = 2LL * seg[rs << 1 | 1].g % mod,seg[rs << 1 | 1].tg = 2LL * seg[rs << 1 | 1].tg % mod;
seg[rs].sum = (seg[rs].f + (seg[rs << 1].sum + seg[rs << 1 | 1].sum) % mod) % mod;
update(l,r,k,ls,tl,mid);
}
else
{
seg[ls].f = (seg[ls].f + (pw[k - 1] - seg[ls].g + mod) % mod) % mod,seg[ls].g = 2LL * seg[ls].g % mod;
push(ls);
seg[ls << 1].f = 2LL * seg[ls << 1].f % mod,seg[ls << 1].tf = 2LL * seg[ls << 1].tf % mod;
seg[ls << 1].sum = 2LL * seg[ls << 1].sum % mod;
seg[ls << 1].g = 2LL * seg[ls << 1].g % mod,seg[ls << 1].tg = 2LL * seg[ls << 1].tg % mod;
seg[ls << 1 | 1].f = 2LL * seg[ls << 1 | 1].f % mod,seg[ls << 1 | 1].tf = 2LL * seg[ls << 1 | 1].tf % mod;
seg[ls << 1 | 1].sum = 2LL * seg[ls << 1 | 1].sum % mod;
seg[ls << 1 | 1].g = 2LL * seg[ls << 1 | 1].g % mod,seg[ls << 1 | 1].tg = 2LL * seg[ls << 1 | 1].tg % mod;
seg[ls].sum = (seg[ls].f + (seg[ls << 1].sum + seg[ls << 1 | 1].sum) % mod) % mod;
update(l,r,k,rs,mid + 1,tr);
}
seg[p].sum = (seg[p].f + (seg[ls].sum + seg[rs].sum) % mod) % mod;
}
int main()
{
scanf("%d%d",&n,&m),build(1,1,n),pw[0] = 1;
for(register int i = 1;i <= m;++i)
pw[i] = 2LL * pw[i - 1] % mod;
int op,l,r;
for(register int cnt = 0;m;--m)
{
scanf("%d",&op);
if(op == 1)
scanf("%d%d",&l,&r),update(l,r,++cnt,1,1,n);
else
printf("%d\n",seg[1].sum);
}
}