洛谷 5220 特工的信息流

这题虽然原 idea 是我的,但是 noname 改了之后我就一直咕咕咕没做了。
今天闲得发慌来练手速……

写完这题的第一感觉是可以回去把「『SDOI2011』颜色」的坑给填了(虽然我还是没打算写

其实做法比较显然,先考虑在序列上的做法,线段树维护区间后缀积之和区间积,那么合并左右子树的时候,根据乘法分配律可得:

\[\sum_{i=l}^r\prod_{j=i}^r a_j = \prod_{i=m+1}^r a_i\left(\sum_{i=l}^m\prod_{j=i}^m a_j\right) + \sum\limits_{i=m+1}^r\prod\limits_{j=i}^r a_j (m \in [l,r))\]

于是就显然。

考虑把这个做法放到树上,但是这个时候我们发现如果把路径从 LCA 划分成两段的话,有一段需要用与答案相反前缀积之和来统计,于是改一改线段树就好了。

关于此题树剖做法的码量瓶颈,我认为应该在于查询的过程……
必须保证思路清晰才能不写错。

然后是取模的问题,虽然模数很小但是也印证了那句话:

不开 long long 见祖宗,十年 OI 一场空。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const long long mod = 20924;
int n,m;
long long a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v;
pre[tot] = first[u];
first[u] = tot;
}
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5],id[N + 5],rk[N + 5];
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1;
dfs1(to[i]),sz[p] += sz[to[i]];
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
rk[id[p] = ++tot] = p;
if(!son[p])
return ;
top[son[p]] = top[p],dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!id[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
struct segnode
{
long long prod,sufsum,presum;
} seg[(N << 2) + 10];
void build(int p,int tl,int tr)
{
if(tl == tr)
{
seg[p].prod = seg[p].sufsum = seg[p].presum = a[rk[tl]];
return ;
}
int mid = tl + tr >> 1;
build(ls,tl,mid);
build(rs,mid + 1,tr);
seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
void modify(int x,int k,int p,int tl,int tr)
{
if(tl == tr)
{
seg[p].prod += k,seg[p].sufsum += k,seg[p].presum += k;
return ;
}
int mid = tl + tr >> 1;
if(x <= mid)
modify(x,k,ls,tl,mid);
else
modify(x,k,rs,mid + 1,tr);
seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
long long query_prod(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].prod;
int mid = tl + tr >> 1;
long long ret = 1;
if(l <= mid)
ret = ret * query_prod(l,r,ls,tl,mid) % mod;
if(r > mid)
ret = ret * query_prod(l,r,rs,mid + 1,tr) % mod;
return ret;
}
long long query_sufsum(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].sufsum;
int mid = tl + tr >> 1;
if(l <= mid && r > mid)
return (query_sufsum(l,r,ls,tl,mid) * query_prod(l,r,rs,mid + 1,tr) % mod + query_sufsum(l,r,rs,mid + 1,tr)) % mod;
if(l <= mid)
return query_sufsum(l,r,ls,tl,mid);
else
return query_sufsum(l,r,rs,mid + 1,tr);
}
long long query_presum(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].presum;
int mid = tl + tr >> 1;
if(l <= mid && r > mid)
return (query_presum(l,r,rs,mid + 1,tr) * query_prod(l,r,ls,tl,mid) % mod + query_presum(l,r,ls,tl,mid)) % mod;
if(l <= mid)
return query_presum(l,r,ls,tl,mid);
if(r > mid)
return query_presum(l,r,rs,mid + 1,tr);
}
pair<int,int> getlca(int x,int y)
{
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
return dep[x] < dep[y] ? make_pair(x,0) : make_pair(y,1);
}
long long query(int x,int y)
{
pair<int,int> t = getlca(x,y);
int lca = t.first,w = t.second;
vector< pair<int,int> > range;
while(top[x] ^ top[lca])
range.push_back(make_pair(id[top[x]],id[x])),x = fa[top[x]];
if(w)
range.push_back(make_pair(id[lca],id[x]));
long long temp = 0,prod = 1;
for(register int i = range.size() - 1;~i;--i)
temp = (temp + query_presum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
range.clear();
while(top[y] ^ top[lca])
range.push_back(make_pair(id[top[y]],id[y])),y = fa[top[y]];
if(!w)
range.push_back(make_pair(id[lca],id[y]));
long long ret = 0;
prod = 1;
for(register int i = 0;i < range.size();++i)
ret = (ret + query_sufsum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
return (ret + temp * prod % mod) % mod;
}
int main()
{
scanf("%d%d",&n,&m);
for(register int i = 1;i <= n;++i)
scanf("%lld",a + i);
int u,v;
for(register int i = 1;i < n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dep[1] = 1,dfs1(1),top[1] = 1,dfs2(1);
build(1,1,n);
char op;
int x,y;
while(m--)
{
scanf(" %c%d%d",&op,&x,&y);
if(op == 'Q')
printf("%lld\n",(query(x,y) + mod) % mod);
else
modify(id[x],y,1,1,n);
}
}