JZOJ 5018 合并奶牛

虽然明知有简洁明了的卡特兰数 DP 做法,还是写了题解的 DP 套 DP(

考虑给定最终序列 \(c\) 如何判定其是否合法:设 \(g_{i,j}\) 表示检查了最终序列的前 \(i\) 位,第一队的前 \(j\) 位,目前是否合法。
那么有 \(g_{i,j} = (c_i = a_j \land g_{i-1,j-1}) \lor (c_i = b_{i-j} \land g_{i-1,j})\),其中 \(a,b\) 分别表示第一、二队。
考虑利用这个 DP 的思路,套一个 DP 来计数。

由于两队内部互不相同,所以 \(g_{i.j} = 1 \Rightarrow c_i = a_j \lor c_i = b_{i-j}\),于是若确定了最终序列的第 \(i\) 位,对应的 \(g_i\) 最多有两个为 \(1\)
于是设 \(f_{i,j,k=0/1/2}\) 表示填了 \(c\) 的前 \(i\) 位,\(c_i = j\)\(g_i\) 中两个值第一个为 \(1\) / 第二个为 \(1\) / 两个均为 \(1\)
然后转移……整一些很恶心的分类讨论(

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <cstdio>
#include <cstring>
#define add(x,y) (x + y >= mod ? x + y - mod : x + y)
#define dec(x,y) (x < y ? x - y + mod : x - y)
using namespace std;
const int N = 2e3;
const int mod = 998244353;
int n,a[N + 5],b[N + 5];
int posa[N + 5],posb[N + 5];
int vis[N + 5][3];
int f[(N << 1) + 5][N + 5][3],t[N + 5],T,ans;
void update(int i,int j,int k,int v)
{
if(t[j] ^ T)
t[j] = T,memset(vis[j],0,sizeof vis[j]);
if(vis[j][k])
return ;
if(k < 2 && vis[j][k ^ 1])
f[i][j][k ^ 1] = dec(f[i][j][k ^ 1],v),f[i][j][2] = add(f[i][j][2],v),
vis[j][0] = vis[j][1] = vis[j][2] = 1;
else
f[i][j][k] = add(f[i][j][k],v),vis[j][k] = 1;
}
int main()
{
freopen("merge.in","r",stdin),freopen("merge.out","w",stdout);
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%d",a + i),posa[a[i]] = i;
for(register int i = 1;i <= n;++i)
scanf("%d",b + i),posb[b[i]] = i;
if(a[1] == b[1])
f[1][a[1]][2] = 1;
else
f[1][a[1]][0] = f[1][b[1]][1] = 1;
for(register int i = 1;i < 2 * n;++i)
for(register int j = 1;j <= n;++j)
{
if(f[i][j][0])
{
int x = posa[j],y = i - x;
if(b[y + 1] == j)
f[i + 1][a[x + 1]][0] = add(f[i + 1][a[x + 1]][0],f[i][j][0]),
f[i + 1][b[y + 1]][2] = add(f[i + 1][b[y + 1]][2],f[i][j][0]);
else if(a[x + 1] == b[y + 1])
f[i + 1][a[x + 1]][2] = add(f[i + 1][a[x + 1]][2],f[i][j][0]);
else
f[i + 1][a[x + 1]][0] = add(f[i + 1][a[x + 1]][0],f[i][j][0]),
f[i + 1][b[y + 1]][1] = add(f[i + 1][b[y + 1]][1],f[i][j][0]);
}
if(f[i][j][1])
{
int y = posb[j],x = i - y;
if(a[x + 1] == j)
f[i + 1][a[x + 1]][0] = add(f[i + 1][a[x + 1]][0],f[i][j][1]),
f[i + 1][a[x + 1]][2] = add(f[i + 1][a[x + 1]][2],f[i][j][1]);
else if(a[x + 1] == b[y + 1])
f[i + 1][b[y + 1]][2] = add(f[i + 1][b[y + 1]][2],f[i][j][1]);
else
f[i + 1][a[x + 1]][0] = add(f[i + 1][a[x + 1]][0],f[i][j][1]),
f[i + 1][b[y + 1]][1] = add(f[i + 1][b[y + 1]][1],f[i][j][1]);
}
if(f[i][j][2])
{
++T;
int x1 = posa[j],y1 = i - x1;
int y2 = posb[j],x2 = i - y2;
if(a[x1 + 1] == b[y1])
update(i + 1,a[x1 + 1],2,f[i][j][2]);
else
update(i + 1,a[x1 + 1],0,f[i][j][2]);
if(a[x1] == b[y1 + 1])
update(i + 1,b[y1 + 1],2,f[i][j][2]);
else
update(i + 1,b[y1 + 1],1,f[i][j][2]);
if(a[x2 + 1] == b[y2])
update(i + 1,a[x2 + 1],2,f[i][j][2]);
else
update(i + 1,a[x2 + 1],0,f[i][j][2]);
if(a[x2] == b[y2 + 1])
update(i + 1,b[y2 + 1],2,f[i][j][2]);
else
update(i + 1,b[y2 + 1],1,f[i][j][2]);
}
}
for(register int i = 0;i < 3;++i)
if(a[n] == b[n])
ans = add(ans,f[2 * n][a[n]][i]);
else
ans = add(ans,f[2 * n][a[n]][i]),
ans = add(ans,f[2 * n][b[n]][i]);
printf("%d\n",ans);
}