LibreOJ 2473.「九省联考 2018」秘密袭击

题目所求为每个连通块的第 \(k\) 大权值之和,相当于枚举 \(w = 1\dots W\) 求满足第 \(k\) 大权值不小于 \(w\) 的连通块个数,又相当于枚举 \(w = 1\dots W\) 求满足不小于 \(w\) 的权值至少有 \(k\) 个的连通块个数。

枚举 \(w\),设 \(f_{u,w}(i)\) 表示在以 \(u\) 为根的子树内选择一个包含 \(u\) 的连通块,满足其中不小于 \(w\) 的权值有 \(i\) 个的方案数。
考虑合并 \(u\) 和儿子 \(v\) 的状态: \[ f'_{u,w}(i) = f_{u,w}(i) + \prod\limits_{j=0}^i f_{u,w}(j) f_{v,w}(i-j) \] (为了美观,假设 \(i<0\)\(f_{u,w}(i) = 0\)。)

\(F_u(x)\) 表示 \(f_u\) 的生成函数,则有 \[ F'_{u,w}(x) = F_{u,w}(x) (1 + F_{v,w}(x)) \]

考虑按照点值计算,则枚举 \(x = 1 \dots n+1\),利用线段树按照 \(w\) 维护整体 DP。
为了保证复杂度,同时需要维护 \(G_{u,w}(x)\) 表示 \(\sum\limits_{p \in {\rm subtree}(u)} F_{p,w}(x)\)

对于每个点,一开始时 \(F_{u,w}(x) = x^{[d_u \ge w]}\)
则考虑在线段树上维护变换 \((a,b,c,d)\) 表示 \((f,g) \to (af+b,g+cf+d)\)
则容易合并两个变换 \((a_1,b_1,c_1,d_1),(a_2,b_2,c_2,d_2)\)\[ (a_1a_2,b_1a_2+b_2,c_1+a_1c_2,b_1c_2+d_1+d_2) \]

此外,你需要一个 \(O(n^2)\) 的拉格朗日插值。
如果点值形如 \((i,y_i)\) 的话,可以考虑范德蒙德矩阵求逆。
不过可以首先暴力计算出 \(x - x_i\) 的乘积(注意这是一个 \(n+1\) 次多项式),然后每次模拟一下多项式除法即可计算 \(\prod\limits_{i\ne j} (x - x_j)\)

大常数选手不开 O2 过不去(

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 1666;
const unsigned mod = 64123;
inline unsigned fpow(unsigned a,int b)
{
unsigned ret = 1;
for(;b;b >>= 1)
(b & 1) && (ret = ret * a % mod),a = a * a % mod;
return ret;
}
int n,k,w,x;
int a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
struct Value
{
unsigned a,b,c,d;
inline Value(unsigned p = 1,unsigned q = 0,unsigned r = 0,unsigned s = 0)
{
a = p,b = q,c = r,d = s;
}
inline Value operator*(const Value &o) const
{
return Value(a * o.a % mod,(b * o.a + o.b) % mod,(c + a * o.c) % mod,(b * o.c + d + o.d) % mod);
}
};
namespace SEG
{
struct node
{
Value val;
int ls,rs;
} seg[(N << 7) + 5];
int seg_tot;
inline void push(int p)
{
int &tot = seg_tot;
!seg[p].ls && (seg[seg[p].ls = ++tot] = seg[0],1),
seg[seg[p].ls].val = seg[seg[p].ls].val * seg[p].val;
!seg[p].rs && (seg[seg[p].rs = ++tot] = seg[0],1),
seg[seg[p].rs].val = seg[seg[p].rs].val * seg[p].val;
seg[p].val = Value();
}
unsigned query(int p,int tl,int tr)
{
if(tl == tr)
return seg[p].val.d;
push(p);
int mid = tl + tr >> 1;
unsigned ret = 0;
ret = (ret + query(seg[p].ls,tl,mid)) % mod;
ret = (ret + query(seg[p].rs,mid + 1,tr)) % mod;
return ret;
}
void update(int l,int r,Value k,int &p,int tl,int tr)
{
int &tot = seg_tot;
!p && (seg[p = ++tot] = seg[0],1);
if(l <= tl && tr <= r)
{
seg[p].val = seg[p].val * k;
return ;
}
push(p);
int mid = tl + tr >> 1;
l <= mid && (update(l,r,k,seg[p].ls,tl,mid),1);
r > mid && (update(l,r,k,seg[p].rs,mid + 1,tr),1);
}
int merge(int p,int q)
{
if(!p || !q)
return p | q;
if(!seg[p].ls && !seg[p].rs)
{
seg[q].val = seg[q].val * Value(seg[p].val.b,seg[p].val.b,0,seg[p].val.d);
return q;
}
if(!seg[q].ls && !seg[q].rs)
{
seg[p].val = seg[p].val * Value((seg[q].val.b + 1) % mod,0,0,seg[q].val.d);
return p;
}
push(p),push(q);
seg[p].ls = merge(seg[p].ls,seg[q].ls),
seg[p].rs = merge(seg[p].rs,seg[q].rs);
return p;
}
}
int fa[N + 5];
int rt[N + 5];
unsigned ansp[N + 5],ansc[N + 5],ans;
unsigned fac[N + 5],ifac[N + 5],inv[N + 5];
unsigned f[N + 5];
void dfs(int p)
{
SEG::update(1,a[p],Value(1,x,0,0),rt[p],1,w),a[p] < w && (SEG::update(a[p] + 1,w,Value(1,1,0,0),rt[p],1,w),1);
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
fa[to[i]] = p,
dfs(to[i]),
rt[p] = SEG::merge(rt[p],rt[to[i]]);
SEG::seg[rt[p]].val = SEG::seg[rt[p]].val * Value(1,0,1,0);
}
int main()
{
scanf("%d%d%d",&n,&k,&w);
for(register int i = 1;i <= n;++i)
scanf("%d",a + i);
int u,v;
for(register int i = 2;i <= n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
for(x = 1;x <= n + 1;++x)
memset(rt,0,sizeof rt),SEG::seg_tot = 0,
dfs(1),
ansp[x] = SEG::query(rt[1],1,w);
fac[0] = 1;
for(register int i = 1;i <= n + 1;++i)
fac[i] = fac[i - 1] * i % mod;
ifac[n + 1] = fpow(fac[n + 1],mod - 2);
for(register int i = n + 1;i;--i)
ifac[i - 1] = ifac[i] * i % mod;
for(register int i = 1;i <= n + 1;++i)
inv[i] = ifac[i] * fac[i - 1] % mod;
f[0] = 1;
for(register int i = 1;i <= n + 1;++i)
{
for(register int j = n + 1;j;--j)
f[j] = (f[j] * (mod - i) + f[j - 1]) % mod;
f[0] = f[0] * (mod - i) % mod;
}
for(register int i = 1;i <= n + 1;++i)
{
f[0] = (mod - f[0] * inv[i] % mod) % mod;
for(register int j = 1;j <= n + 1;++j)
f[j] = (mod - (f[j] - f[j - 1] + mod) % mod * inv[i] % mod) % mod;
ansp[i] = ansp[i] * ifac[i - 1] % mod * ifac[n + 1 - i] % mod * ((n + 1 - i & 1) ? mod - 1 : 1) % mod;
for(register int j = 0;j <= n;++j)
ansc[j] = (ansc[j] + ansp[i] * f[j]) % mod;
for(register int j = n + 1;j;--j)
f[j] = (f[j] * (mod - i) + f[j - 1]) % mod;
f[0] = f[0] * (mod - i) % mod;
}
for(register int i = k;i <= n;++i)
ans = (ans + ansc[i]) % mod;
printf("%u\n",ans);
}