JZOJ 6514 树上的数

这个题好神仙哦(

\(f(n)\) 表示点权乘积为 \(n\) 的所有方案的权值(答案)之和。
不难看出这个函数是个积性函数,因为权值互质意味着两棵树的所有点都两两互质,那么 \(\gcd\) 什么的就直接乘起来就可以。

既然要求是正奇数,那么只需要强制令偶数处的 \(f\) 值为 \(0\) 即可。
则注意到 \(f(p) = [p > 2]np\),其中 \(p\) 为质数。

考虑 Min_25 筛,则需要求出 \(f(p^k)\) 的值。
注意到若乘积为 \(p^k\) 的形式,那么每个点的权值显然也是 \(p^k\) 的形式。
又注意到由于不需要考虑 \(2\),所以 \(k \le \lfloor \log_3 10^{10} \rfloor = 20\)
于是可以全部 \(\log p\),设 \(\mathrm{cnt}_{x,y}\) 表示给这棵树上赋点权使得点权之和为 \(x\),所有路径上的点权 \(\min\) 之和为 \(y\) 的方案数。

首先考虑一下当 \(x\) 确定时 \(y\) 的上界。
\(g_{u,i} = [a_u \ge i]\),则 \(\sum\limits_{u=1}^n \sum\limits_{i=1}^x g_{u,i} = \sum\limits_{u=1}^n a_u = x\)
\(\mathrm{Path}\) 为所有路径的集合,则有 \[ \begin{aligned} y &= \sum\limits_{S \in \mathrm{Path}} \min\limits_{u \in S} a_u \\ &= \sum\limits_{S \in \mathrm{Path}} \sum\limits_{i=1}^x \prod\limits_{u \in S} g_{u,i} \\ &= \sum\limits_{i=1}^x \sum\limits_{S \in \mathrm{Path}} \prod\limits_{u \in S} g_{u,i} \end{aligned} \]

注意到对 \(0/1\) 进行乘法相当于取 \(\min\),所以可以通过只乘路径的起点和终点来进行放缩。
\[ \begin{aligned} y &\le \sum\limits_{i=1}^k \sum\limits_{u=1}^{n-1} \sum\limits_{v=u+1}^n g_{u,i} g_{v,i} \\ &= \frac12 \sum\limits_{i=1}^k \left(\sum\limits_{u=1}^n g_{u,i}\right)\left(1 + \sum\limits_{u=1}^n g_{u,i}\right) \\ &\le \frac12 \left(\sum\limits_{i=1}^k \sum\limits_{u=1}^n g_{u,i}\right)\left(1+\sum\limits_{i=1}^k \sum\limits_{u=1}^n g_{u,i}\right) \\ &\le \frac{x(x+1)}2 \end{aligned} \]

(DH 大佬的证法看得不是很懂,感觉自己只能从他的思路证出 \(y \le \frac{x(n+1)}2\)

于是考虑一个非常暴力的 DP,设 \(f_{u,i,v,S}\) 表示 \(u\) 的子树内权值和为 \(i\),所有路径上点权 \(\min\) 之和为 \(v\),所有点到 \(u\) 的路径上的 \(\min\) 值的可重集为 \(S\)
转移考虑枚举 \(a_u\),对 \(i\) 维做树形背包,另外枚举当前的 \(v,S\) 和儿子的 \(v,S\) 并合并状态。
这个东西显然是不能过的,复杂度我也算不出来(

注意到显然有很多无用状态,首先考虑有多少种 \(S\)
首先发现 \(S\) 中的 \(0\) 是无用的,可以忽略。
由于 \(S\) 的元素和显然不超过 \(i\),所以 \(S\) 最多有 \(\sum\limits_{j=1}^i \operatorname{partition}(i)\) 种,其中 \(\operatorname{partition}(n)\) 表示 \(n\) 的划分数。
\(i=20\) 时这个值为 \(2714\),可以接受。

然后考虑将 \(v,S\) 视作一个二元组,在数据随机的情况下,且确定了 \(u,i\) 的情况下,大概有 \(17\,000\) 种不同的这样的二元组,依然可以接受。

最后一个问题在于如何离散化状态,以及加速状态合并。
关于离散化状态,首先先通过状态压缩 \(S\)\(S\) 离散化。
(状压可重集的方式是同一种元素用若干个连续的 \(1\) 表示这个元素的个数,不同种类元素用 \(0\) 隔开)
然后直接开一个数组来存储 \((v,S)\) 的对应关系。

加速状态合并的话,考虑合并 \((v_0,S_0),(v_1,S_1)\)
\[\left(v_0+v_1+\sum\limits_{x \in S_0} \sum\limits_{y \in S_1} \min\{x,y\},S_0 \cup \left\{\min\{y,w\}\mid y \in S_1\right\}\right)\]

这些东西都预处理一下就可以了。

细节超多(
请配合 O2 使用(

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#pragma GCC optimize(2)
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define add(a,b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a,b) (a < b ? a - b + mod : a - b)
using namespace std;
const long long N = 1e10;
const int MX = 1e5;
const int M = 100;
const int K = 20;
const int V = 210;
const int S = 2714;
const int W = 17000;
const int mod = 998244353;
const int inv = 499122177;
long long n;
int m,k;
namespace DP
{
int to[(M << 1) + 5],pre[(M << 1) + 5],first[M + 5];
inline void add_edge(int u,int v)
{
static int tot = 0;
to[++tot] = v,pre[tot] = first[u],first[u] = tot;
}
struct Cnt
{
int a[K + 5];
inline void clear()
{
memset(a,0,sizeof a);
}
inline int &operator[](const int &x)
{
return a[x];
}
inline const int &operator[](const int &x) const
{
return a[x];
}
inline bool check() const
{
int sum = 0;
for(register int i = 1;i <= k;++i)
sum += a[i] * i;
return sum <= k;
}
} cur,s[S + 5];
int pl[S + 5][S + 5],mn[S + 5][K + 5];
int mer[S + 5][S + 5];
int mem[(1 << K + 1) + 5];
int sum,tot;
int &id(const Cnt &a)
{
int s = 0;
for(register int i = k;s <<= 1,i;--i)
for(register int j = 1;j <= a[i];++j)
s = (s << 1) | 1;
return mem[s];
}
void dfs_pre(int x)
{
if(x > k)
{
s[id(cur) = ++tot] = cur;
return ;
}
for(register int i = 0;sum + i * x <= k;++i)
sum += i * x,cur[x] = i,dfs_pre(x + 1),sum -= i * x;
}
int f[M + 5][K + 5][W + 5],g[K + 5][W + 5],temp[K + 5][W + 5];
int vs[M + 5][K + 5][W + 5][2];
int rk[K + 5][V + 5][S + 5];
int fa[M + 5];
int cnt[K + 5][V + 5],up[M + 5][K + 5];
inline int get(int p,int j,int v,int s)
{
return rk[j][v][s] ? rk[j][v][s] : (vs[p][j][rk[j][v][s] = ++up[p][j]][0] = v,vs[p][j][up[p][j]][1] = s,up[p][j]);
}
void dfs(int p)
{
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
fa[to[i]] = p,dfs(to[i]);
memset(rk,0,sizeof rk);
for(register int w = 0,a,b;w <= k;++w)
{
memset(g,0,sizeof g),cur.clear(),++cur[w],++g[w][get(p,w,w,id(cur))];
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
memset(temp,0,sizeof temp);
for(register int u = 0;u <= k;++u)
for(register int x = 1,r = up[p][u];x <= r;++x)
if(g[u][x])
for(register int v = 0;u + v <= k;++v)
for(register int y = 1;y <= up[to[i]][v];++y)
if(f[to[i]][v][y])
{
a = vs[p][u][x][0] + vs[to[i]][v][y][0] + mer[vs[p][u][x][1]][vs[to[i]][v][y][1]],b = pl[vs[p][u][x][1]][mn[vs[to[i]][v][y][1]][w]];
if(!b)
continue;
int &ans = temp[u + v][get(p,u + v,a,b)];
ans = (ans + (long long)g[u][x] * f[to[i]][v][y]) % mod;
}
memcpy(g,temp,sizeof g);
}
for(register int i = 0;i <= k;++i)
for(register int j = 1;j <= up[p][i];++j)
f[p][i][j] = add(f[p][i][j],g[i][j]);
}
}
void main()
{
scanf("%d%lld",&m,&n),k = log(n) / log(3);
dfs_pre(1);
for(register int i = 1;i <= tot;++i)
{
for(register int x = 0;x <= k;++x)
{
cur.clear();
for(register int u = 1;u <= k;++u)
cur[min(u,x)] += s[i][u];
cur.check() && (mn[i][x] = id(cur),1);
}
for(register int j = 1;j <= tot;++j)
{
cur.clear();
for(register int u = 1;u <= k;++u)
cur[u] = s[i][u] + s[j][u];
cur.check() && (pl[i][j] = id(cur),1);
cur.clear();
for(register int u = 1;u <= k;++u)
cur[u] = cur[u - 1] + u * s[j][u];
for(register int u = 1;u <= k;++u)
mer[i][j] += s[i][u] * cur[u],
mer[j][i] += s[i][u] * cur[u - 1];
}
}
int u,v;
for(register int i = 1;i < m;++i)
scanf("%d%d",&u,&v),add_edge(u,v),add_edge(v,u);
tot = 0,dfs(1);
for(register int i = 0;i <= k;++i)
for(register int j = 1;j <= up[1][i];++j)
cnt[i][vs[1][i][j][0]] = add(cnt[i][vs[1][i][j][0]],f[1][i][j]);
}
}
namespace Min_25
{
int lim;
int vis[MX + 5],cnt,prime[MX + 5],Gprime[MX + 5];
int tot,le[MX + 5],ge[MX + 5];
long long lis[2 * MX + 5];
int G[2 * MX + 5],Fprime[2 * MX + 5];
int w[MX + 5][K + 5];
inline int &id(long long x)
{
return x <= lim ? le[x] : ge[n / x];
}
int F(int k,long long n)
{
if(n < prime[k] || n <= 1)
return 0;
int ret = (Fprime[id(n)] - (long long)m * Gprime[k - 1] % mod + mod) % mod;
if(k == 1)
ret = dec(ret,(m << 1));
for(register int i = max(k,2);i <= cnt && (long long)prime[i] * prime[i] <= n;++i)
{
long long pw = prime[i],pw2 = (long long)prime[i] * prime[i];
for(register int c = 1;pw2 <= n;++c,pw = pw2,pw2 *= prime[i])
{
if(w[i][c])
ret = (ret + (long long)w[i][c] * F(i + 1,n / pw) % mod) % mod;
ret = add(ret,w[i][c + 1]);
}
}
return ret;
}
void main()
{
lim = sqrt(n);
for(register int i = 2;i <= MX;++i)
{
if(!vis[i])
{
prime[++cnt] = i,Gprime[cnt] = add(Gprime[cnt - 1],i);
for(register int j = 0;cnt > 1 && j <= k;++j)
for(register int u = 0,pw = 1;u <= k * (k + 1) / 2;++u,pw = (long long)pw * i % mod)
w[cnt][j] = (w[cnt][j] + (long long)DP::cnt[j][u] * pw) % mod;
}
for(register int j = 1;j <= cnt && i * prime[j] <= MX;++j)
{
vis[i * prime[j]] = 1;
if(!(i % prime[j]))
break;
}
}
for(register long long l = 1,r;l <= n;l = r + 1)
{
r = n / (n / l);
lis[id(n / l) = ++tot] = n / l;
G[tot] = (n / l % mod + 2) * (n / l % mod - 1 + mod) % mod * inv % mod;
}
for(register int k = 1;k <= cnt;++k)
{
int p = prime[k];
long long s = (long long)p * p;
for(register int i = 1;lis[i] >= s;++i)
G[i] = (G[i] - (long long)p * (G[id(lis[i] / p)] - Gprime[k - 1] + mod) % mod + mod) % mod;
}
for(register int i = 1;i <= tot;++i)
Fprime[i] = (long long)m * G[i] % mod;
printf("%d\n",(F(1,n) + 1) % mod);
}
}
int main()
{
freopen("number.in","r",stdin),freopen("number.out","w",stdout);
DP::main(),Min_25::main();
}