JZOJ 4370.Hypocritical

假设题目中给出的是一棵 Trie,可以直接建广义 SAM,每次插入字符的时候把对应的 Hypocritical 值的贡献加入到 SAM 上的结点的 DP 数组中。

如果不是 Trie,容易发现可以把同一结点下字符相同的儿子合并,从而得到一棵 Trie;因为答案只和串结尾处的 Hypocritical 值相关,所以合并并不会影响答案。
(实际上也可以不合并成 Trie)

代码:
#include <cstdio>
#include <cstring>
using namespace std;
// ---- Limits -----------------------------------------------------------
const int N = 100001;        // array capacity bound for tree / trie nodes
const int K = 16;            // highest polynomial degree that is tracked
const int mod = 998244353;   // every coefficient is kept modulo this value

// ---- Input and output tables ------------------------------------------
int n;                   // node count; main() increments it once after reading
int m;                   // number of queries still to answer
int a[N + 5];            // a[v]: value of tree node v, reduced with % mod on input
int ans[N + 5][K + 5];   // ans[len][k]: number printed for the query (len, k)
char s[N + 5];           // s[v]: letter of tree node v; input is read into s + 2

// ---- Trie ---------------------------------------------------------------
// One trie node.
struct node
{
    int ch[5];   // child trie node per letter index, 0 when absent
    int ed;      // automaton state recorded for this trie node
};
node tr[N + 5];

int f[N + 5][K + 5];   // f[t]: polynomial coefficients attached to trie node t
int tot = 1;           // number of trie nodes in use; node 1 is the root
/**
 * Multiplies two polynomials, truncates the product to degree K and stores
 * it back into f, with every coefficient reduced modulo mod.
 *
 * f  in/out: K + 1 coefficients, lowest degree first.
 * g  input:  K + 1 coefficients, lowest degree first; not modified.
 *
 * Only indices 0..K of either array are read or written.
 */
inline void merge(int *f,const int *g)
{
    // f is both an operand and the destination, so accumulate in a scratch
    // buffer and copy back once the whole product is known.
    int h[K + 1];
    memset(h,0,sizeof h);
    for(int i = 0;i <= K;++i)
        for(int j = 0;j <= i;++j)
            // Widen before multiplying: the product of two residues does not fit in int.
            h[i] = (h[i] + static_cast<long long>(f[i - j]) * g[j]) % mod;
    memcpy(f,h,sizeof h);
}
/**
 * Suffix automaton whose states carry a degree-K polynomial.
 *
 * The caller chooses the state to extend from by assigning `las` before each
 * insert(), so the automaton can be grown along several branches.
 * build() must be called once after all insertions; it fills the global
 * `ans` table.
 */
namespace SAM
{
// One automaton state.
struct node
{
    int ch[5];   // transition per letter index, 0 when absent
    int fa;      // suffix link
    int len;     // length of the longest string of the state
} sam[(N << 1) + 5];
int tot = 1,las = 1;          // state 1 is the initial state; las is the state insert() extends
int c[N + 5];                 // counting-sort buckets keyed by len
int a[(N << 1) + 5];          // a[2..tot]: non-initial states in non-decreasing len order
int f[(N << 1) + 5][K + 5];   // f[v]: polynomial attached to state v

/**
 * Appends letter x after state `las` and attaches polynomial g to the newly
 * created state. On return `las` is that new state.
 *
 * x  letter index, used directly as an index into ch[].
 * g  K + 1 coefficients that are multiplied into the new state's polynomial.
 */
inline void insert(int x,int *g)
{
    int cur = las;
    const int p = ++tot;
    sam[p].len = sam[cur].len + 1;
    // Give every suffix-link ancestor that lacks an x-transition one into p.
    for(;cur && !sam[cur].ch[x];cur = sam[cur].fa)
        sam[cur].ch[x] = p;
    if(!cur)
        sam[p].fa = 1;
    else
    {
        const int q = sam[cur].ch[x];
        if(sam[q].len == sam[cur].len + 1)
            sam[p].fa = q;
        else
        {
            // q also holds longer strings: split off a clone with the right length.
            // Only the node record is copied; f[nxt] keeps whatever it held before.
            const int nxt = ++tot;
            sam[nxt] = sam[q];
            sam[nxt].len = sam[cur].len + 1;
            sam[p].fa = sam[q].fa = nxt;
            for(;cur && sam[cur].ch[x] == q;cur = sam[cur].fa)
                sam[cur].ch[x] = nxt;
        }
    }
    las = p;
    merge(f[p],g);
}

/**
 * Propagates the polynomials up the suffix-link tree and fills `ans`.
 *
 * Each state v adds f[v] at index len(v) and subtracts it at index
 * len(fa(v)); the final suffix sums therefore give, for every length L,
 * the sum of f[v] over the states v with len(fa(v)) < L <= len(v).
 */
void build()
{
    // Counting sort of the states by len. The initial state is the only one
    // with len 0, so it owns slot 1 and the others land in a[2..tot].
    for(int i = 1;i <= tot;++i)
        ++c[sam[i].len];
    for(int i = 1;i <= n;++i)
        c[i] += c[i - 1];
    for(int i = tot;i > 1;--i)
        a[c[sam[i].len]--] = i;

    // Walk from the longest state to the shortest, so f[v] is complete
    // before it is multiplied into its suffix-link parent.
    for(int i = tot;i > 1;--i)
    {
        const int v = a[i];
        const int u = sam[v].fa;
        for(int k = 0;k <= K;++k)
        {
            ans[sam[v].len][k] = (ans[sam[v].len][k] + f[v][k]) % mod;
            ans[sam[u].len][k] = (ans[sam[u].len][k] - f[v][k] + mod) % mod;
        }
        merge(f[u],f[v]);
    }

    // Suffix sums turn the difference entries into per-length totals.
    for(int k = 0;k <= K;++k)
        for(int j = n;j;--j)
            ans[j][k] = (ans[j][k] + ans[j + 1][k]) % mod;
}
}
// Queue of trie nodes; main() reads at ++head and writes at ++tail.
int q[N + 5];
int head;
int tail;

// Linked adjacency list of the tree, filled by add():
//   first[u] - id of the edge most recently added at u (0 when u has none)
//   pre[e]   - id of the edge added at the same vertex before e (0 at the end)
//   to[e]    - endpoint that edge e leads to
int to[(N << 1) + 5];
int pre[(N << 1) + 5];
int first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
int fa[N + 5];   // fa[v]: tree node from which dfs() reached v
int rt[N + 5];   // rt[v]: trie node that tree node v is mapped to
/**
 * Walks the tree from node p and folds it into the trie.
 *
 * The polynomial of p's trie node is multiplied by (1 + a[p] * x), truncated
 * to degree K. Every child of p is then mapped to the trie child of rt[p]
 * for the child's letter, creating that trie node on first use, so children
 * with the same letter share one trie node.
 *
 * Requires rt[p] and fa[p] to be set before the call. Recurses once per tree
 * level, so the call depth equals the height of the tree.
 */
void dfs(int p)
{
    int *poly = f[rt[p]];
    // Descending order: poly[i - 1] is still the old coefficient when read.
    for(int i = K;i;--i)
        poly[i] = (poly[i] + static_cast<long long>(poly[i - 1]) * a[p]) % mod;

    for(int e = first[p];e;e = pre[e])
    {
        const int child = to[e];
        if(child == fa[p])
            continue;   // do not walk back up the edge we arrived by
        fa[child] = p;
        // NOTE(review): ch[] has 5 slots, so letters are presumably 'a'..'e' - confirm against the input format.
        const int x = s[child] - 'a';
        int &slot = tr[rt[p]].ch[x];
        if(!slot)
            slot = ++tot;
        rt[child] = slot;
        dfs(child);
    }
}
/**
 * Reads the tree, folds it into a trie, builds the automaton over the trie
 * in breadth-first order and answers the queries from the `ans` table.
 *
 * Input nodes are shifted up by one so that node 1 can act as an extra root
 * placed above the first input node.
 */
int main()
{
    // Every polynomial starts out as the constant 1, the neutral element of merge().
    for(int i = 1;i <= N;++i)
        f[i][0] = 1;
    for(int i = 1;i <= (N << 1);++i)
        SAM::f[i][0] = 1;

    // The third integer of the header is read and discarded (%*d).
    scanf("%d%d%*d%s",&n,&m,s + 2);
    ++n;
    for(int i = 2;i <= n;++i)
    {
        scanf("%d",a + i);
        a[i] %= mod;
    }

    // One-way edge from the extra root to the first input node.
    add(1,2);
    int u,v;
    for(int i = 3;i <= n;++i)
    {
        scanf("%d%d",&u,&v);
        ++u;
        ++v;
        add(u,v);
        add(v,u);
    }

    rt[1] = 1;
    dfs(1);

    // Breadth-first walk over the trie: each child extends the automaton
    // from the state recorded for its trie parent.
    q[++tail] = 1;
    tr[1].ed = 1;
    while(head < tail)
    {
        const int p = q[++head];
        for(int i = 0;i < 5;++i)
        {
            const int child = tr[p].ch[i];
            if(!child)
                continue;
            SAM::las = tr[p].ed;
            SAM::insert(i,f[child]);
            tr[child].ed = SAM::las;
            q[++tail] = child;
        }
    }
    SAM::build();

    int len,k;
    for(;m;--m)
    {
        scanf("%d%d",&len,&k);
        printf("%d\n",ans[len][k]);
    }
}