以下设 \(n = |S|\)。
弱周期引理:设 \(p,q\) 为 \(S\) 的周期,\(p + q \le n\),则 \(\gcd(p,q)\) 也是 \(S\) 的周期。
不妨设 \(p > q\),则令 \(d = p - q\)。
对于 \(i - q > 0\),有 \(S_i = S_{i-q} = S_{i+p-q} = S_{i+d}\)。
对于 \(i + p \le n\),有 \(S_i = S_{i+p} = S_{i+p-q} = S_{i+d}\)。
当 \(p + q \le n\) 时不存在 \(i\) 同时不满足以上两者。
\(S\) 的所有长度不小于 \(\frac n2\) 的 border 的长度为一个等差数列。
设 \(p\) 为 \(S\) 的最小周期,不妨设 \(p \le \frac n2\)。
设 \(q\) 为 \(S\) 的一个周期,满足 \(q \le \frac n2\)。
则由弱周期引理可知 \(\gcd(p,q)\) 同为 \(S\) 的周期。
由上可知 \(\gcd(p,q) \ge p\)。
故显然 \(p \mid q\)。
字符串 \(S\) 的所有 border 按长度排序后可以划分为 \(O(\log n)\) 段等差数列。
将 \(S\) 的 border 长度 \(x\) 分类为 \([1,2),[2,4),\dots,[2^{k-1},2^k),[2^k,n)\qquad (2^k \ge \frac n2)\)。
对于 \(x \in [2^k,n)\),由以上引理可得 \(x\) 构成一个等差数列。
对于 \(x \in [2^{i-1},2^i)\):
设 \(m\) 为 \(\max x\),易知剩下的 \(x\) 也为 \(S_{1\dots m}\) 的 border。
而 \(x \ge \frac m2\),故由上引理易知 \(x\) 构成一个等差数列。
按时间分治,设当前在处理 \([s,e]\) 内的修改,我们注意到每个修改会导致树上一条祖先后代链的数对被删除。
假设每次是从一个点 \(v\) 往上删到一点 \(u\),我们不妨将所有 \(u\) 的父亲和所有 \(v\) 拿出来建出虚树,其余的边都可以扔掉。
途中需要用线段树来维护新增的数对。
直接实现就是 \(O((N+Q)\log^2(N+Q))\) 的。
然而事实上并不需要显式建树,注意到可以用单调栈来在按 DFS 序遍历的时候(也就是按左端点升序遍历)维护每个点的所有祖先。
可能会好写一点(大概)。
另,zkw 线段树在这题上很有用。
代码(从某份 std 那里抄来了 zkw):
// Pull the standard-library names used below (pair, vector, max, sort, ...) into scope.
using namespace std;
// Sentinel "infinite" value; main() stores it in a[0] and a[n + 1].
const int inf = 0x3f3f3f3f;
// Capacities used to size the static arrays below
// (presumably the maximum sequence length and number of modifications — confirm against the statement).
const int N = 1e5;
const int Q = 2.5e5;
// Both are read from the input in main().
int n, q;
// a[i]: current value at position i. las[]: scratch array main() uses while building pre[] / nxt[].
int a[N + 5], las[N + 5];
// upd[t] = (position, value) of the t-th assignment: main() fills entries 1..n from the initial
// array and entries n+1..n+q from the input.
pair<int, int> upd[N + Q + 5];
// pre[t] / nxt[t]: index of the previous / next entry of upd[] that touches the same position as
// upd[t]. nxt[t] is n + q + 1 when there is no later entry; pre[t] stays 0 for the first n entries.
int pre[N + Q + 5], nxt[N + Q + 5];
// Non-recursive (bottom-up) max segment tree over positions 0..S-1.
// The leaf of position i is seg[S + i]; an internal node u stores max(seg[2u], seg[2u + 1]).
struct SegmentTree {
static const int S = 1 << 17;
int seg[(S << 1) + 5];
// Point assignment: store k at position u, then recompute the max of every ancestor.
void insert(int u, int k) {
seg[u += S] = k;
for (u >>= 1; u; u >>= 1)
seg[u] = max(seg[u << 1], seg[u << 1 | 1]);
}
// Looks leftwards from position u (inclusive) for the nearest position whose value is >= k.
// Returns that position only if its value is exactly k; returns 0 otherwise.
int queryL(int u, int k) {
// Climb: (u - 1) >> 1 is the parent when u is a right child, and the node immediately left of the
// parent when u is a left child. Stop when a subtree max reaches k or u is a power of two
// (the leftmost node of its level).
for (u += S; seg[u] < k && (u & (u - 1)); u = (u - 1) >> 1);
if (seg[u] < k) return 0;
// Descend to the rightmost leaf with value >= k: step to the left child, then switch to the
// right child whenever the right child's max is >= k.
for (; u < S; u |= seg[u | 1] >= k) u <<= 1;
return seg[u] == k ? u - S : 0;
}
// Mirror image of queryL: looks rightwards from position u (inclusive) for the nearest position
// whose value is >= k. Returns it only if its value is exactly k; returns n + 1 otherwise.
int queryR(int u, int k) {
// Climb towards the right; u & (u + 1) == 0 means u is the rightmost node of its level.
for (u += S; seg[u] < k && (u & (u + 1)); u = (u + 1) >> 1);
if (seg[u] < k) return n + 1;
// Descend to the leftmost leaf with value >= k: take the left child unless its max is < k.
for (; u < S; u |= seg[u] < k) u <<= 1;
return seg[u] == k ? u - S : n + 1;
}
} seg;
// One tree node: four public ints (kept in this order so structured bindings
// `auto [l, r, h, w] = ...` keep working). Ordering and equality look at `l` only.
struct node {
    int l, r, h, w;

    node(int l_ = 0, int r_ = 0, int h_ = 0, int w_ = 0)
        : l(l_), r(r_), h(h_), w(w_) {}

    // Nodes are ordered by their l field alone.
    bool operator<(const node &other) const { return other.l > l; }

    // Two nodes compare equal as soon as their l fields match.
    bool operator==(const node &other) const { return !(l != other.l); }
};
// Rebuilds and compresses the node set for one step of the time divide-and-conquer.
//   newStatic - positions whose value a[] is inserted into the global max tree `seg` here
//   upd       - (position, value) assignments the caller still has to process; this parameter
//               shadows the global upd[] array
//   tr        - current nodes; the stack sweeps below assume they are sorted by l and properly nested
// Returns {compressed node list, cnt}, where cnt is the total weight w of the rebuilt tree that is
// NOT represented in the returned nodes.
// NOTE(review): per the accompanying write-up this is the "virtual tree" over the ancestor chains
// the pending updates can delete, built with a monotonic stack instead of an explicit tree —
// confirm before relying on that reading.
pair<vector<node>, int> build(vector<int> newStatic, vector<pair<int, int>> upd, vector<node> tr) {
sort(newStatic.begin(), newStatic.end()), sort(upd.begin(), upd.end());
// anc[1..top] is a stack of indices into tr; newTr collects the nodes of the rebuilt tree.
static int anc[N + 5];
static node newTr[N * 2 + 5];
int j = 0, top = 0, tot = 0;
// Positions that are still going to be assigned are zeroed in the max tree.
for (auto [i, _]: upd) seg.insert(i, 0);
// Sweep the newly fixed positions from left to right, keeping on the stack the nodes that start
// at or before the current position.
for (int i: newStatic) {
seg.insert(i, a[i]);
for (; j < tr.size() && tr[j].l <= i; ++j) {
auto [l, r, h, w] = tr[j];
// Stack entries that end no later than the new node's r are popped and kept.
for (; top && tr[anc[top]].r <= r; --top) newTr[tot++] = tr[anc[top]];
anc[++top] = j;
}
// Nodes ending before position i are popped and kept ...
for (; top && tr[anc[top]].r < i; --top) newTr[tot++] = tr[anc[top]];
// ... while the remaining stack nodes with h <= a[i] are popped WITHOUT being copied, i.e. dropped.
for (; top && tr[anc[top]].h <= a[i]; --top);
}
// Everything left on the stack and every node not reached by the sweep is kept.
for (; top; --top) newTr[tot++] = tr[anc[top]];
for (; j < tr.size(); ++j) newTr[tot++] = tr[j];
// Each newly fixed position contributes weight-1 nodes linking it to the positions returned by
// queryL / queryR (0 and n + 1 mean "none").
for (int i: newStatic) {
int l = seg.queryL(i - 1, a[i]);
if (l) newTr[tot++] = node(l, i, a[i], 1);
int r = seg.queryR(i + 1, a[i]);
if (r <= n) newTr[tot++] = node(i, r, a[i], 1);
}
// Sort by l and drop duplicates; node::operator== compares l only.
tr = vector<node>(newTr, newTr + tot), sort(tr.begin(), tr.end()), tr.erase(unique(tr.begin(), tr.end()), tr.end());
// sum[i]: w of node i plus sum[] of the node below it on the stack; cnt: total w of all nodes.
vector<int> sum(tr.size());
int cnt = 0;
for (int i = 0; i < tr.size(); ++i) {
auto [l, r, h, w] = tr[i];
for (; top && tr[anc[top]].r <= r; --top);
if (top) sum[i] = sum[anc[top]];
sum[i] += w, cnt += w;
anc[++top] = i;
}
// crit: indices of the nodes that must survive the compression; index 0 is always included.
vector<int> crit(1, 0);
crit.reserve(upd.size() + 1);
j = top = 0;
for (auto [i, h]: upd) {
for (; j < tr.size() && tr[j].l <= i; ++j) {
auto [l, r, h, w] = tr[j];
for (; top && tr[anc[top]].r <= r; --top);
anc[++top] = j;
}
for (; top && tr[anc[top]].r < i; --top);
// The top of the stack after discarding nodes that end before position i.
crit.push_back(anc[top]);
// Binary search on the stack for the highest stack position whose h exceeds the assigned value
// (this relies on h being monotone along the stack).
int l = 1, r = top, mid, res = 1;
while (l <= r) {
mid = (l + r) >> 1;
if (tr[anc[mid]].h > h) l = mid + 1, res = mid;
else r = mid - 1;
}
crit.push_back(anc[res]);
}
sort(crit.begin(), crit.end()), crit.erase(unique(crit.begin(), crit.end()), crit.end());
// For every pair of consecutive critical nodes, also keep the stack node found below
// (presumably their lowest common ancestor — confirm).
// The reserve keeps the push_back calls inside the loop from reallocating.
crit.reserve(crit.size() * 2);
j = top = 0;
for (int i = 0, siz = crit.size(); i + 1 < siz; ++i) {
for (; j <= crit[i]; ++j) {
auto [l, r, h, w] = tr[j];
for (; top && tr[anc[top]].r <= r; --top);
anc[++top] = j;
}
for (; top && tr[anc[top]].r <= tr[crit[i + 1]].r; --top);
crit.push_back(anc[top]);
}
sort(crit.begin(), crit.end()), crit.erase(unique(crit.begin(), crit.end()), crit.end());
top = 0;
// Emit the compressed nodes: each one's new weight is sum[] of the node minus sum[] of the critical
// node below it on the stack; that weight is removed from cnt.
for (int i = 0; i < crit.size(); ++i) {
auto [l, r, h, w] = tr[crit[i]];
for (; top && tr[anc[top]].r <= r; --top);
w = sum[crit[i]];
if (top) w -= sum[anc[top]];
cnt -= w;
anc[++top] = crit[i];
newTr[i] = node(l, r, h, w);
}
return {vector<node>(newTr, newTr + crit.size()), cnt};
}
// Difference array over time: "ans[x] += v, ans[y + 1] -= v" adds v to every time in [x, y].
// main() turns it into actual answers with a prefix sum.
int ans[N + Q + 5];
// Divide and conquer over the indices [l, r] of upd[]; `tr` is the node set already compressed
// for this range by the caller.
void solve(int l, int r, vector<node> tr) {
if (l == r) {
// Single assignment: apply it, rebuild with that one position fixed and no pending updates,
// and credit the weight build() reports to time l only.
a[upd[l].first] = upd[l].second;
int tag = build({upd[l].first}, {}, tr).second;
ans[l] += tag, ans[l + 1] -= tag;
return ;
}
int mid = (l + r) >> 1;
vector<int> tmp;
vector<pair<int, int>> updTmp;
// vis[] marks the positions currently listed in tmp; it is cleared again before each build().
static bool vis[N + 5];
// The left half is only processed when it contains an index >= n
// (main() prints ans[n..n + q] only).
if (mid >= n) {
// Positions assigned in (mid, r] whose previous assignment lies before l are not assigned
// anywhere in [l, mid].
for (int i = mid + 1; i <= r; ++i) if (pre[i] < l) tmp.push_back(upd[i].first), vis[upd[i].first] = 1;
// Pending updates for the left half: its own assignments, plus the previous assignment of every
// position whose predecessor lies before l.
updTmp = vector<pair<int, int>>(upd + l, upd + mid + 1);
for (int i = l; i <= mid; ++i) if (pre[i] && pre[i] < l && !vis[upd[pre[i]].first]) updTmp.push_back(upd[pre[i]]);
for (int i: tmp) vis[i] = 0;
auto [L, lTag] = build(tmp, updTmp, tr);
// The weight compressed away is added to every time in [l, mid].
ans[l] += lTag, ans[mid + 1] -= lTag;
// Release the temporaries before recursing.
vector<int>().swap(tmp), vector<pair<int, int>>().swap(updTmp);
solve(l, mid, L);
}
// Right half: positions assigned in [l, mid] whose next assignment lies after r.
for (int i = l; i <= mid; ++i) if (nxt[i] > r) tmp.push_back(upd[i].first), vis[upd[i].first] = 1;
updTmp = vector<pair<int, int>>(upd + mid + 1, upd + r + 1);
for (int i = mid + 1; i <= r; ++i) if (pre[i] && pre[i] <= mid && !vis[upd[pre[i]].first]) updTmp.push_back(upd[pre[i]]);
for (int i: tmp) vis[i] = 0;
auto [R, rTag] = build(tmp, updTmp, tr);
ans[mid + 1] += rTag, ans[r + 1] -= rTag;
solve(mid + 1, r, R);
}
// Reads the initial array and the q assignments, links every assignment to the
// previous/next one on the same position, runs the divide-and-conquer and prints
// the prefix-summed answers for times n..n+q.
int main() {
    scanf("%d%d", &n, &q);
    a[n + 1] = inf;
    a[0] = a[n + 1];

    // The initial values act as the first n assignments.
    for (int idx = 1; idx <= n; ++idx) {
        scanf("%d", &a[idx]);
        upd[idx] = make_pair(idx, a[idx]);
        las[idx] = idx;
    }

    // pre[t]: latest earlier assignment to the same position.
    for (int t = n + 1; t <= n + q; ++t) {
        scanf("%d%d", &upd[t].first, &upd[t].second);
        const int where = upd[t].first;
        pre[t] = las[where];
        las[where] = t;
    }

    // nxt[t]: earliest later assignment to the same position, n + q + 1 if none.
    fill(las + 1, las + n + 1, n + q + 1);
    for (int t = n + q; t != 0; --t) {
        const int where = upd[t].first;
        nxt[t] = las[where];
        las[where] = t;
    }

    solve(1, n + q, {node(0, n + 1, inf, 0)});

    for (int t = 1; t <= n + q; ++t)
        ans[t] += ans[t - 1];
    for (int t = n; t <= n + q; ++t)
        printf("%d\n", ans[t]);
}
建图,对于逆序对 \(i < j, A_i > A_j\) 连边 \(i \to j\)。这也叫做排列图(Permutation Graph)。则每次就是删去最多两个入度为 \(0\) 的点。
对于任意 DAG,有论文给出了这样一个方法:
这样我们就至少获得了一个多项式时间的做法。考虑在排列图上如何优化之。
首先,计算出 \(level(i)\) 表示以 \(i\) 结尾的最长路。也就是以 \(i\) 结尾的最长下降子序列长度。
我们按照 \(level(i)\) 升序处理所有点,不难发现这样和拓扑排序是等价的,且同层间没有连边。
那么,分层做,我们只需要支持对任意 \(v,w\) 比较 \(N(v),N(w)\) 的字典序即可。由于已经降序,所以我们事实上只需要比较 \(N(v)-N(w),N(w)-N(v)\) 中的最大值。
这是单点修改,矩形 \(\max\),用树套树即可做到单次 \(O(\log^2 n)\),总共 \(O(n\log^3 n)\)。
还是有点慢,但我们发现如果我们需要先求这个 3-side 矩形中点的最大 \(level(i)\)(这是静态的),是可以用主席树 \(O(n\log n)\) 预处理 \(O(\log n)\) 查询的(不过原论文好像用的是划分树状物,可能这就是学术界做法吧)。
而同 \(level\) 的点一定是反链,所以可以转化为区间 \(\max\) 查询。
这样就做到了 \(O(n\log^2 n)\)。
但是这似乎还是太 naive。
(似乎)可以观察到最小时间等于 \(N\) 减去补图的最大匹配。进一步,有论文证明了本题的答案和补图的最大匹配是双射(具体的映射应该很好猜),那么我们对于一组最大匹配只需要一直找入度为 \(0\) 的点就可以排出一个操作顺序。这是 \(O(n\log n)\) 的。
而目前有论文声称可以在 \(O(n\log \log n)\) 时间内计算排列图(显然排列图的补图也是排列图)的最大匹配,感觉很 nb。
代码(实现了 \(3\log\) 的做法):
using namespace std;
// Value the segment tree stores for erased positions (larger than any real value).
const int inf = 0x3f3f3f3f;
// Capacity of the static arrays (presumably the maximum n — confirm against the statement).
const int N = 1e5;
// n: input size; p[1..n]: input sequence; q[]: inverse of p (main() sets q[p[i]] = i).
int n, p[N + 5], q[N + 5];
// level[i]: value main() computes for position i from the BIT, before sorting by it.
int level[N + 5];
// a[]: positions in processing order; b[]: inverse of a (main() sets b[a[i]] = i).
int a[N + 5], b[N + 5];
// Fenwick tree for suffix maxima: update() walks downwards by clearing the lowest
// set bit, query() walks upwards by adding it, so query(x) sees every value stored
// at an index >= x.
struct BinaryIndexedTree {
    int c[N + 5];

    // Raise the stored maximum to at least k on every node visited from x down to 0.
    void update(int x, int k) {
        while (x) {
            if (c[x] < k) c[x] = k;
            x &= x - 1;
        }
    }

    // Maximum over the nodes visited from x upwards while x <= n (0 if none is larger).
    int query(int x) {
        int best = 0;
        while (x <= n) {
            best = max(best, c[x]);
            x += x & -x;
        }
        return best;
    }
} bit;
// Two-level max structure: an outer segment tree over the first coordinate (1..n) whose nodes
// each own a dynamically allocated inner segment tree over the second coordinate (1..n).
// NOTE(review): `ls` / `rs` used as bare names in the outer-tree methods are not declared in the
// visible code — presumably macros for the two children of node u that were lost with the other
// preprocessor lines; confirm against the original source.
struct TreeOfTree {
// Inner-tree node: subtree maximum plus child indices into seg[] (0 = child not allocated).
struct segnode {
int max;
int ls, rs;
} seg[N * 700 + 5];
// Inner insert: raise the maximum at coordinate x to at least k, allocating nodes on demand
// (u is passed by reference so a fresh node is linked into its parent).
void insert(int x, int k, int &u, int tl, int tr) {
static int tot = 0;
if (!u) u = ++tot;
seg[u].max = max(seg[u].max, k);
if (tl == tr) return ;
int mid = (tl + tr) >> 1;
if (x <= mid) insert(x, k, seg[u].ls, tl, mid);
else insert(x, k, seg[u].rs, mid + 1, tr);
}
// Inner query: maximum over [l, r]. A missing node (u == 0) returns seg[0].max, which is never written.
int query(int l, int r, int u, int tl, int tr) {
if (!u || (l <= tl && tr <= r)) return seg[u].max;
int mid = (tl + tr) >> 1;
int ret = 0;
if (l <= mid) ret = max(ret, query(l, r, seg[u].ls, tl, mid));
if (r > mid) ret = max(ret, query(l, r, seg[u].rs, mid + 1, tr));
return ret;
}
// rt[u]: root of the inner tree owned by outer node u.
int rt[N * 4 + 5];
// Outer insert: record value k at point (x, y) in the inner tree of every outer node on the path to x.
void insert(int x, int y, int k, int u, int tl, int tr) {
insert(y, k, rt[u], 1, n);
if (tl == tr) return ;
int mid = (tl + tr) >> 1;
if (x <= mid) insert(x, y, k, ls, tl, mid);
else insert(x, y, k, rs, mid + 1, tr);
}
void insert(int x, int y, int k) { insert(x, y, k, 1, 1, n); }
// Outer query: maximum over the rectangle [l, r] x [x, y].
int query(int l, int r, int x, int y, int u, int tl, int tr) {
if (l <= tl && tr <= r) return query(x, y, rt[u], 1, n);
int mid = (tl + tr) >> 1;
int ret = 0;
if (l <= mid) ret = max(ret, query(l, r, x, y, ls, tl, mid));
if (r > mid) ret = max(ret, query(l, r, x, y, rs, mid + 1, tr));
return ret;
}
// Public entry point; an empty rectangle yields 0.
int query(int l, int r, int x, int y) { return l <= r && x <= y ? query(l, r, x, y, 1, 1, n) : 0; }
} tr;
// Segment tree over positions 1..n storing the minimum of p[] (erased positions hold inf), plus
// per-node `vis` flags that stop search() from descending into a subtree twice and a lazy `tag`
// that clears `vis` for a whole subtree.
// NOTE(review): `ls` / `rs` are not declared in the visible code — presumably macros for the two
// children of node u; confirm against the original source.
struct SegmentTree {
struct node {
int min;
bool vis, tag;
} seg[N * 4 + 5];
void build(int u, int tl, int tr) {
if (tl == tr) { seg[u].min = p[tl]; return ; }
int mid = (tl + tr) >> 1;
build(ls, tl, mid), build(rs, mid + 1, tr);
seg[u].min = min(seg[ls].min, seg[rs].min);
}
void build() { build(1, 1, n); }
// Largest position before x whose stored value is smaller than p[x], or 0 if there is none.
// The right subtree is tried first; when x lies in the left half only the left subtree is searched.
int prev(int x, int u, int tl, int tr) {
if (seg[u].min >= p[x]) return 0;
if (tl == tr) return tl;
int mid = (tl + tr) >> 1;
if (x <= mid) return prev(x, ls, tl, mid);
int ret = prev(x, rs, mid + 1, tr);
if (!ret) ret = prev(x, ls, tl, mid);
return ret;
}
int prev(int x) { return prev(x, 1, 1, n); }
// Push a pending "clear vis" tag down to both children.
void pushDown(int u) {
if (seg[u].tag) {
seg[ls].vis = 0, seg[ls].tag = 1;
seg[rs].vis = 0, seg[rs].tag = 1;
seg[u].tag = 0;
}
}
// Clear `vis` on every node that intersects [l, r]; fully covered nodes get the lazy tag.
void undo(int l, int r, int u, int tl, int tr) {
seg[u].vis = 0;
if (l <= tl && tr <= r) { seg[u].tag = 1; return ; }
pushDown(u);
int mid = (tl + tr) >> 1;
if (l <= mid) undo(l, r, ls, tl, mid);
if (r > mid) undo(l, r, rs, mid + 1, tr);
}
void undo(int l, int r) { undo(l, r, 1, 1, n); }
// Overwrite position x with inf and refresh the minima on the path to the root.
void erase(int x, int u, int tl, int tr) {
if (tl == tr) { seg[u].min = inf; return ; }
int mid = (tl + tr) >> 1;
if (x <= mid) erase(x, ls, tl, mid);
else erase(x, rs, mid + 1, tr);
seg[u].min = min(seg[ls].min, seg[rs].min);
}
// Remove position x: the positions in (prev(x), x] become searchable again, then x is erased.
void erase(int x) {
int y = prev(x, 1, 1, n);
undo(y + 1, x, 1, 1, n), erase(x, 1, 1, n);
}
// Collect the not-yet-visited positions whose value is smaller than `rmin`, the minimum of
// everything already scanned to their right. The right child is always scanned first.
void search(vector<int> &ret, int rmin, int u, int tl, int tr) {
if (seg[u].min >= rmin || seg[u].vis) return ;
seg[u].vis = 1;
if (tl == tr) { ret.push_back(tl); return ; }
pushDown(u);
int mid = (tl + tr) >> 1;
search(ret, rmin, rs, mid + 1, tr);
search(ret, min(rmin, seg[rs].min), ls, tl, mid);
}
vector<int> search() {
vector<int> ret;
search(ret, inf, 1, 1, n);
return ret;
}
} seg;
// Operations in the order they are produced (reversed before printing).
vector<pair<int, int>> ans;
// Max-heap of the ranks b[] of the positions that are currently available.
priority_queue<int> que;

// Reads the sequence, computes a processing order (level first, then the
// rectangle-max comparison inside each level), then repeatedly removes the one
// or two highest-ranked available positions and prints the resulting pairs.
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", p + i), q[p[i]] = i;

    // level[i] = maximum value stored in the BIT at indices > p[i]; the BIT then
    // stores level[i] + 1 at index p[i].
    for (int i = 1; i <= n; ++i)
        bit.update(p[i], (level[i] = bit.query(p[i] + 1)) + 1);

    iota(a + 1, a + n + 1, 1), sort(a + 1, a + n + 1, [](int x, int y) { return level[x] < level[y]; });

    // Order each group of equal level with two rectangle-max queries over the
    // ranks inserted so far, then insert the group with its final ranks.
    for (int l = 1, r; l <= n; l = r + 1) {
        for (r = l; r < n && level[a[r + 1]] == level[a[l]]; ++r);
        sort(a + l, a + r + 1, [](int x, int y) {
            // Normalise to x < y and remember whether the arguments were swapped.
            bool flag = 0;
            if (x > y) flag = 1, swap(x, y);
            int a = tr.query(1, x - 1, p[x] + 1, p[y]), b = tr.query(x + 1, y - 1, p[y] + 1, n);
            return a != b ? (a < b) ^ flag : 0;
        });
        for (int i = l; i <= r; ++i) tr.insert(a[i], p[a[i]], i);
    }
    for (int i = 1; i <= n; ++i) b[a[i]] = i;

    seg.build();
    for (;;) {
        // Newly available positions enter the heap keyed by their rank.
        auto res = seg.search();
        for (int i: res) que.push(b[i]);
        if (que.empty()) break;
        int x = que.top(); que.pop(), seg.erase(a[x]);
        // A lone element is paired with itself.
        int y = x;
        if (!que.empty()) y = que.top(), que.pop(), seg.erase(a[y]);
        ans.emplace_back(p[a[x]], p[a[y]]);
    }
    reverse(ans.begin(), ans.end());

    // ans.size() is a size_t, so it needs %zu; %d with a size_t argument is undefined behaviour.
    printf("%zu\n", ans.size());
    for (auto i: ans) printf("%d %d\n", i.first, i.second);
}
考虑枚举序列的标准基底。设秩为 \(k\),其中元素的最高位分别为 \(a_0 < a_1 < \dots < a_{k-1}\)。
于是,贡献由三部分组成:秩为 \(k\) 的序列个数、标准基底个数,以及最大异或和的期望值。
秩为 \(k\) 的序列个数是经典问题,在此处已经讨论过(不过这里先不考虑基是哪些元素),于是直接给出结论: \[2^{k(k-1)/2} (k!)_2 \binom nk_2\]
对应的标准基底个数,由于除了钦定的 \(k\) 位以外其他位都可以随意选定,于是显然就是 \[\prod_{i=0}^{k-1} 2^{a_i-i} = 2^{-k(k-1)/2} \prod_{i=0}^{k-1} 2^{a_i}\]
最大异或和的期望值,也就是基底的期望异或和。其中钦定的 \(k\) 位必然为 \(1\),其余可能出现在基底中的位各有 \(1/2\) 的概率为 \(1\)。于是: \[\frac12\left(2^{a_{k-1}+1}-1+\sum_{i=0}^{k-1}2^{a_i}\right)\]
那么我们主要要计算 \[\frac12\left(2^{a_{k-1}+1}-1+\sum_{i=0}^{k-1}2^{a_i}\right) \prod_{i=0}^{k-1} 2^{a_i}\]
对于 \(-1/2\) 项,有 \[-\frac12 [x^k] \prod_{i=0}^{m-1}(1+2^i x) = -2^{k(k-1)/2-1} \binom mk_2\]
对于 \(\frac12 \sum_i 2^{a_i}\) 项,一个很棒的做法是看成 \(2^m-1\) 中减去其他不在基中的项,也就是可以看成选 \(k+1\) 个: \[(2^m-1) 2^{k(k-1)/2-1} \binom mk_2 - (k+1) 2^{k(k+1)/2-1} \binom m{k+1}_2\]
对于 \(2^{a_{k-1}}\) 项,我们可以另列 \[F(x) = \sum_{r=0}^{m-1} 2^{2r} \prod_{i=0}^{r-1} (1+2^ix)\]
这也是 q-D-finite 的。当然,如果你愿意,同样可以表为 q-二项式系数。
起手式是做 \(s_i = \bigoplus_{j=1}^i a_j\),然后对答案数组来一个反向差分(为了方便),则问题变为对于 \([l,r]\) 内满足 \(s_i \oplus s_j = k\) \((i < j)\) 的数对操作 \(b_i \gets b_i + w, b_j \gets b_j - w\)。
好,我们知道有一个题是查询区间内上述数对的个数。那么我们不妨沿用原来的做法,即莫队。则其中一个端点的操作可以现场完成,并对另一个端点会转化出 \(O(n \sqrt m)\) 个四参数的操作 \((l,r,x,w)\) 表示将 \([l,r]\) 内 \(s_i=x\) 的 \(i\) 作 \(b_i \gets b_i+w\)。
当然,因为莫队本身只刻画不同询问间的位移,所以为了能计算答案我们要把所有 \(w\) 后缀和。
自然将不同的 \(s_i\) 划分开,在其中再次差分,则问题变为一些 lower_bound 查询。
带 \(\log\) 肯定是不好的。一个做法是离线,但是那样难免会花费同复杂度的空间,不太划算。事实上莫队本身就在做若干次 \(O(1)\) 位移,这意味着当前的左右端点在任一 \(s_i\) 的等价类中的 lower_bound 都可以同时维护,因为每个位置只从属于一个等价类。
过了几天想了一下,其实还有另一个比较好写的做法。注意到莫队本身就是左右端点一直位移的过程,那么假设我们能将莫队整个扫描过程逆序,那其实就能将对应的修改直接在之后的反向扫描中下放到对应位置上。
其实这是通过转置原理得出的(
常数相当小。在老年 OJ 上仍然能跑出 240ms 的好成绩。
代码:
// Answers are accumulated in unsigned arithmetic; main() prints only their low 30 bits.
using uint = unsigned int;
using namespace std;
// Capacity of the static arrays (presumably the maximum of n and m — confirm).
const int N = 1.5e5;
int n, m, k;
// s[i]: prefix xor of the input (s[0] = 0). sId[i]: dense id (>= 1) of the value s[i].
// sXorK[i]: dense id of s[i] ^ k, or 0 when that value never occurs as a prefix xor.
// tot: number of ids handed out so far.
int s[N + 5], sId[N + 5], sXorK[N + 5], tot;
// Maps a prefix-xor value to its dense id.
unordered_map<int, int> id;
// block: block length for Mo's ordering; pos[i]: 1-based block number of index i.
int block, pos[N + 5];
// One range operation (l, r, w).
struct operation {
int l, r;
uint w;
// Mo's ordering: by block of l, then by r — ascending in odd blocks, descending in even blocks.
inline bool operator<(const operation &o) const {
return pos[l] != pos[o.l] ? pos[l] < pos[o.l] : pos[l] & 1 ? r < o.r : r > o.r;
}
} opt[N + 5];
// cnt[x]: how many indices with id x are inside the current window (slot 0 is never incremented,
// because every sId is >= 1).
int cnt[N + 5];
// sum[x]: weight accumulated per id during the backward sweep; ans[i]: answer (difference) array.
uint sum[N + 5], ans[N + 5];
int main() {
freopen("xor.in", "r", stdin), freopen("xor.out", "w", stdout);
scanf("%d%d%d", &n, &m, &k), block = min<int>(n / sqrt(m) + 1, n);
sId[0] = id[0] = ++tot;
for (int i = 1; i <= n; ++i)
scanf("%d", s + i), s[i] ^= s[i - 1], pos[i] = (i - 1) / block + 1,
sId[i] = id.count(s[i]) ? id[s[i]] : (id[s[i]] = ++tot);
for (int i = 0; i <= n; ++i)
if (id.count(s[i] ^ k))
sXorK[i] = id[s[i] ^ k];
for (int i = 1; i <= m; ++i)
scanf("%d%d%u", &opt[i].l, &opt[i].r, &opt[i].w);
sort(opt + 1, opt + m + 1), opt[0].l = 1;
for (int i = m - 1; i; --i)
opt[i].w += opt[i + 1].w;
int l = 1, r = 0;
++cnt[sId[0]];
for (int i = 1; i <= m; ++i) {
while (r < opt[i].r)
++r, ans[r] += opt[i].w * cnt[sXorK[r]], ++cnt[sId[r]];
while (r > opt[i].r)
--cnt[sId[r]], ans[r] -= opt[i].w * cnt[sXorK[r]], --r;
while (l < opt[i].l)
--cnt[sId[l - 1]], ans[l - 1] += opt[i].w * cnt[sXorK[l - 1]], ++l;
while (l > opt[i].l)
--l, ans[l - 1] -= opt[i].w * cnt[sXorK[l - 1]], ++cnt[sId[l - 1]];
}
for (int i = m; i; --i) {
while (l < opt[i - 1].l)
ans[l - 1] += sum[sId[l - 1]], sum[sXorK[l - 1]] += opt[i].w, ++l;
while (l > opt[i - 1].l)
--l, sum[sXorK[l - 1]] -= opt[i].w, ans[l - 1] -= sum[sId[l - 1]];
while (r < opt[i - 1].r)
++r, sum[sXorK[r]] += opt[i].w, ans[r] -= sum[sId[r]];
while (r > opt[i - 1].r)
ans[r] += sum[sId[r]], sum[sXorK[r]] -= opt[i].w, --r;
}
for (int i = n - 1; i; --i)
ans[i] += ans[i + 1];
for (int i = 1; i <= n; ++i)
printf("%u%c", ans[i] & ((1U << 30) - 1), " \n"[i == n]);
}
事实上我在场上使用多元拉格朗日反演得出了结论,不过在此不表。
考虑一个经典的组合意义:转化为从每个环中选出一个元素的方案数。
枚举 \(c_i\) 表示长度为 \(i\) 的链被选择多少次,则其同时被钦定不在同一环内(这个方案数可以直接通过简单的组合意义得到): \[\sum_{0\le c_i \le k_i} \left(\prod_{j=1}^n \binom{k_j}{c_j}j^{c_j}\right) \frac{(\sum_j k_j-1)!}{(\sum_j c_j-1)!}\]
也就是 \[\left(\sum_{j=1}^n k_j-1\right)! \sum_{i\ge 1} \frac1{(i-1)!} [x^i] \prod_{j=1}^n (1+jx)^{k_j}\]
我们把其放到容斥中,那么看起来我们务必要再用一元计量 \(\sum_j k_j\)。
也就是从环上每断出长 \(l\) 的链会贡献 \((-1)^{l-1} t(1+lx)\)。为了快速计算,我们再用一元 \(u\) 计量环长。当然这里是当做链来处理了,之后每一项要乘环长再除以链数(也就是 \(t\) 的次数)。
注意到 \[\sum_{l\ge 1}(-1)^{l-1} t(1+lx)u^l = \frac{tu}{1+u} + \frac{xtu}{(1+u)^2} = \frac{tu(1+x+u)}{(1+u)^2}\]
拼成链 \[\frac1{1-\frac{tu(1+x+u)}{(1+u)^2}}\]
其任意一项系数是容易提取的(可以按照 \(t,x,u\) 的顺序提取),于是分治并做二维卷积计算系数即可。
为了偷懒我直接写了映射到一维来实现二维的卷积,不过这样会有个问题就是卷积长度达到了 \(2^{22}\),板子里的某个优化就用不了了(不过 DIT DIF 当然还是能写的)。
所以大概常数大了点。
代码:
using ll = long long;
using namespace std;

// All arithmetic below is modulo this prime.
const int mod = 998244353;

// Maps a value from [0, 2 * mod) into [0, mod).
inline int norm(int x) {
    if (x >= mod)
        x -= mod;
    return x;
}

// Maps a value from (-mod, mod) into [0, mod).
inline int reduce(int x) {
    if (x < 0)
        return x + mod;
    return x;
}

// Additive inverse of x in [0, mod).
inline int neg(int x) {
    return x == 0 ? 0 : mod - x;
}

// x <- x + y (mod), for x and y already in [0, mod).
inline void add(int &x, int y) {
    x += y - mod;
    if (x < 0)
        x += mod;
}

// x <- x - y (mod), for x and y already in [0, mod).
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0)
        x += mod;
}

// Fused multiply-add: x <- x + y * z (mod), computed in 64 bits.
inline void fam(int &x, int y, int z) {
    const ll total = (ll)y * z + x;
    x = total % mod;
}

// a^b (mod) by binary exponentiation.
inline int qpow(int a, int b) {
    int result = 1;
    while (b) {
        if (b & 1)
            result = (ll)result * a % mod;
        a = (ll)a * a % mod;
        b >>= 1;
    }
    return result;
}
namespace Poly {
// Largest supported transform length is 2^(LG + 1).
const int LG = 21;
// `+` binds tighter than `<<`, so this is 1 << (LG + 1).
const int N = 1 << LG + 1;
// Generator used by init() to derive the roots of unity (3 is a primitive root of 998244353).
const int G = 3;
// lg2[i]: floor(log2(i)), filled by init().
int lg2[N + 5];
// Factorials, inverse factorials and modular inverses of 0..N / 1..N, filled by init().
int fac[N + 5], ifac[N + 5], inv[N + 5];
// Root-of-unity table, filled by init(); the butterflies read it starting at rt + len.
int rt[N + 5];
// Fills lg2[], the root table rt[], and the factorial / inverse tables. Must run before any transform.
inline void init() {
for (int i = 2; i <= N; ++i)
lg2[i] = lg2[i >> 1] + 1;
// w is G^((mod - 1) >> (LG + 1)), an N-th root of unity (the shift amount parses as LG + 1).
int w = qpow(G, (mod - 1) >> LG + 1);
// Upper half of the table: rt[N / 2 + j] = w^j.
rt[N >> 1] = 1;
for (int i = (N >> 1) + 1; i <= N; ++i)
rt[i] = (ll)rt[i - 1] * w % mod;
// Lower levels are sub-sampled from the level above, so rt[len + j] walks the powers of a
// (2 * len)-th root of unity.
for (int i = (N >> 1) - 1; i; --i)
rt[i] = rt[i << 1];
fac[0] = 1;
for (int i = 1; i <= N; ++i)
fac[i] = (ll)fac[i - 1] * i % mod;
// One modular inversion for N!, then walk downwards: 1/(i-1)! = (1/i!) * i.
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N; i; --i)
ifac[i - 1] = (ll)ifac[i] * i % mod;
// 1/i = (i-1)! / i!.
for (int i = 1; i <= N; ++i)
inv[i] = (ll)ifac[i] * fac[i - 1] % mod;
}
// Dense polynomial modulo `mod`; a[i] is the coefficient of x^i.
struct poly {
vector<int> a;
// A non-zero x becomes the constant polynomial x; zero becomes the empty polynomial.
inline poly(int x = 0) {
if (x)
a.push_back(x);
}
inline poly(const vector<int> &o) {
a = o;
}
inline poly(const poly &o) {
a = o.a;
}
inline int size() const {
return a.size();
}
inline bool empty() const {
return a.empty();
}
inline void resize(int x) {
a.resize(x);
}
// Bounds-checked read: coefficients outside [0, size()) read as 0.
inline int operator[](int x) const {
if (x < 0 || x >= size())
return 0;
return a[x];
}
// Releases the storage as well (swap with an empty vector).
inline void clear() {
vector<int>().swap(a);
}
// First min(n, size()) coefficients, i.e. the polynomial modulo x^n.
inline poly modxn(int n) const {
if (a.empty())
return poly();
n = min(n, size());
return poly(vector<int>(a.begin(), a.begin() + n));
}
// Coefficients in reverse order.
inline poly rever() const {
return poly(vector<int>(a.rbegin(), a.rend()));
}
// Forward transform, in place: butterflies from the largest half-length down to 1,
// with the twiddles for half-length len read from rt + len. size() must be a power of two.
inline void dif() {
int n = size();
for (int len = n >> 1; len; len >>= 1)
for (int j = 0; j < n; j += len << 1)
for (int k = j, *w = rt + len; k < j + len; ++k, ++w) {
int R = norm(a[k] + a[k + len]);
a[k + len] = (ll)*w * (a[k] - a[k + len] + mod) % mod,
a[k] = R;
}
}
// Inverse transform, in place: butterflies from half-length 1 upwards, then reversing a[1..n-1]
// (which turns the forward roots into their inverses) and scaling by 1/n.
inline void dit() {
int n = size();
for (int len = 1; len < n; len <<= 1)
for (int j = 0; j < n; j += len << 1)
for (int k = j, *w = rt + len; k < j + len; ++k, ++w) {
int R = (ll)*w * a[k + len] % mod;
a[k + len] = reduce(a[k] - R),
add(a[k], R);
}
reverse(a.begin() + 1, a.end());
for (int i = 0; i < n; ++i)
a[i] = (ll)a[i] * inv[n] % mod;
}
// type == 1: forward transform; any other value: inverse transform.
inline void ntt(int type = 1) {
type == 1 ? dif() : dit();
}
};
// Bivariate polynomial stored row-major: n rows of m coefficients, entry (x, y) at a[x * m + y].
struct poly2D {
int n, m;
vector<int> a;
inline poly2D(int r = 0, int s = 0): n(r), m(s) {
a.resize(n * m);
}
inline poly2D(int r, int s, vector<int> vec): n(r), m(s), a(vec) {}
inline int size() const {
return a.size();
}
inline bool empty() const {
return a.empty();
}
inline void resize(int r, int s) {
n = r, m = s, a.resize(n * m);
}
// Bounds-checked read on the flat index: anything outside [0, size()) reads as 0.
inline int operator[](int x) const {
if (x < 0 || x >= size())
return 0;
return a[x];
}
// 2D product computed with ONE 1D convolution: both operands are re-laid-out with the row stride
// of the result (m = a.m + b.m - 1), so the column sums never spill into the next row.
friend inline poly2D operator*(const poly2D &a, const poly2D &b) {
int n = a.n + b.n - 1, m = a.m + b.m - 1;
poly aBuf, bBuf, resBuf;
poly2D ret(n, m);
// lim: smallest power of two that holds all n * m result coefficients.
int tot = n * m, lim = 1;
for (; lim < tot; lim <<= 1);
aBuf.resize(lim), bBuf.resize(lim), resBuf.resize(lim);
for (int i = 0; i < a.n * a.m; ++i) {
int x = i / a.m, y = i % a.m;
aBuf.a[x * m + y] = a[i];
}
for (int i = 0; i < b.n * b.m; ++i) {
int x = i / b.m, y = i % b.m;
bBuf.a[x * m + y] = b[i];
}
aBuf.ntt(), bBuf.ntt();
// Pointwise product (resBuf starts zeroed, so fam() just stores the product).
for (int i = 0; i < lim; ++i)
fam(resBuf.a[i], aBuf[i], bBuf[i]);
resBuf.ntt(-1);
for (int i = 0; i < tot; ++i)
ret.a[i] = resBuf[i];
return ret;
}
};
}
using Poly::fac;
using Poly::ifac;
using Poly::inv;
using Poly::init;
using Poly::poly;
using Poly::poly2D;
inline int binom(int n, int m) {
return n < m || m < 0 ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
// Capacity of the permutation arrays.
const int N = 2e3;
// Permutation length and the permutation itself (read in main()).
int n;
int p[N + 5];

// Disjoint-set forest over the elements 1..N with path compression.
// fa[x] == 0 marks x as a root; size[] is only meaningful at roots.
struct UnionFind {
    int fa[N + 5], size[N + 5];

    // Every element starts as its own root of size 1. fa[] is cleared
    // explicitly so the structure is also valid for automatic or heap
    // instances, not just for objects with static storage.
    inline UnionFind() {
        for (int i = 0; i < N + 5; ++i)
            fa[i] = 0, size[i] = 0;
        for (int i = 1; i <= N; ++i)
            size[i] = 1;
    }

    // True when x is the representative of its set.
    inline bool isRoot(int x) {
        return !fa[x];
    }

    // Representative of x's set; compresses the path on the way back.
    inline int find(int x) {
        return isRoot(x) ? x : fa[x] = find(fa[x]);
    }

    // Merges the sets of x and y (x's root is attached below y's root).
    inline void merge(int x, int y) {
        int fx = find(x), fy = find(y);
        if (fx != fy)
            fa[fx] = fy, size[fy] += size[fx];
    }
} uf;
// Heap ordering for the merge queue: with std::priority_queue this puts the
// polynomial with the fewest coefficients on top.
struct comparer {
    inline bool operator()(const poly2D &x, const poly2D &y) {
        return y.size() < x.size();
    }
};
// Per-cycle polynomials, smallest first, so that merging always multiplies the two smallest.
priority_queue<poly2D, vector<poly2D>, comparer> q;
// Running answer; starts as the product of the cycle lengths (sign-adjusted below).
int ans = 1;
int main() {
freopen("C.in", "r", stdin), freopen("C.out", "w", stdout);
init();
scanf("%d", &n);
// Union i with p[i]: each resulting set is one cycle of the permutation.
for (int i = 1; i <= n; ++i)
scanf("%d", p + i), uf.merge(i, p[i]);
for (int i = 1; i <= n; ++i)
if (uf.isRoot(i)) {
int size = uf.size[i];
ans = (ll)ans * size % mod;
// One (size + 1) x (size + 1) polynomial per cycle; entry (k, t) sits at index k * (size + 1) + t.
poly2D buf(size + 1, size + 1);
for (int k = 1; k <= size; ++k)
for (int t = 0; t <= k; ++t) {
buf.a[k * (size + 1) + t] = (ll)binom(size + t - 1, t + k - 1) * binom(k, t) % mod * size % mod * inv[k] % mod;
// Sign (-1)^(size - k).
if ((size - k) & 1)
buf.a[k * (size + 1) + t] = neg(buf[k * (size + 1) + t]);
}
// Constant term: (-1)^size * size.
buf.a[0] = size & 1 ? neg(size) : size;
q.push(buf);
}
if (n & 1)
ans = neg(ans);
// Multiply all per-cycle polynomials together, always combining the two smallest.
while (q.size() > 1) {
poly2D x = q.top();
q.pop(), x = x * q.top(), q.pop(), q.push(x);
}
// The product has n + 1 rows and n + 1 columns, hence the row stride n + 1 below.
poly2D f = q.top();
for (int k = 1; k <= n; ++k)
for (int t = 1; t <= k; ++t)
ans = (ans + (ll)fac[k - 1] * ifac[t - 1] % mod * f[k * (n + 1) + t]) % mod;
printf("%d\n", ans);
}
引理.
模质数 \(q\) 意义下 \(n \times m\) 的秩为 \(r\) 的矩阵的个数为 \[q^{r(r-1)/2} (r!)_q \binom nr_q \binom mr_q\]
证明:
显然秩为 \(r\) 的 \(r \times m\) 矩阵个数为 \[\prod_{i=0}^{r-1} (q^m - q^i) = q^{r(r-1)/2} (r!)_q \binom mr_q\]
插入剩下的数的方案数是 \[[z^{n-r}] \prod_{i=0}^r \frac1{1-q^iz}\]
用类似 \(q\)-二项式定理证明的方法可知其等于 \(\binom nr_q\)。
来看这题。
考虑斯特林容斥,这样去掉两个限制的 \(i\times j\times H\) 的立方体的个数就是 \[\prod_{k=1}^H q^{r_k(r_k-1)/2} (r_k!)_q \binom i{r_k}_q \binom j{r_k}_q\]
可以提出 \(\prod_{k=1}^H q^{r_k(r_k-1)/2} (r_k!)_q\),那么除此之外总共就是 \[\left(\sum_{i=1}^L (-1)^{L-i} {L\brack i} \prod_{k=1}^H \binom i{r_k}_q\right)\left(\sum_{j=1}^W (-1)^{W-j} {W \brack j} \prod_{k=1}^H \binom j{r_k}_q\right)\]
这东西不像能结合起来算的样子,所以考虑算出所有 \[f_i = \prod_{k=1}^H \binom i{r_k}_q\]
注意到 \[\frac{f_i}{ f_{i-1} } = \prod_{k=1}^H \frac{1-q^i}{ 1-q^{i-r_k} }\]
那么用 CZT 算出 \[\prod_{k=1}^H (1-q^{-r_k}z) = \exp\left(-\sum_{i\ge 1} \frac{z^i}i \sum_{k=1}^H q^{-ir_k}\right)\]
之后再做一次 CZT 即可。
接下来算答案。
就是计算 \[\newcommand\me{ \mathrm e }\sum_{i\ge 0} f_i [t^i] \me^{t\ln(1+z)} = \sum_{i\ge 0} f_i [t^i] (1+z)^t\]
转置一下可以发现就是下降幂转普通幂: \[\sum_{i\ge 0} f_i [z^i] \me^{t\ln(1+z)} = \sum_{i\ge 0} f_i \binom ti\]
分治 NTT 即可。
代码:
using ll = long long;
using namespace std;

// Prime modulus for every computation in this file.
const int mod = 998244353;

// Brings a value from [0, 2 * mod) back into [0, mod).
inline int norm(int x) {
    return x < mod ? x : x - mod;
}

// Brings a value from (-mod, mod) back into [0, mod).
inline int reduce(int x) {
    return x >= 0 ? x : x + mod;
}

// Returns -x modulo mod.
inline int neg(int x) {
    if (!x)
        return 0;
    return mod - x;
}

// In-place modular addition of two residues.
inline void add(int &x, int y) {
    x += y - mod;
    if (x < 0)
        x += mod;
}

// In-place modular subtraction of two residues.
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0)
        x += mod;
}

// In-place x += y * z modulo mod, with the product taken in 64 bits.
inline void fam(int &x, int y, int z) {
    const ll product = (ll)y * z;
    x = (x + product) % mod;
}

// Modular exponentiation a^b by repeated squaring.
inline int qpow(int a, int b) {
    int acc = 1;
    while (b) {
        if (b & 1)
            acc = (ll)acc * a % mod;
        a = (ll)a * a % mod;
        b >>= 1;
    }
    return acc;
}

// Capacity of the problem-level arrays.
const int N = 1e5;
namespace Poly {
// Largest supported transform length is 2^(LG + 1).
const int LG = 18;
// `+` binds tighter than `<<`, so this is 1 << (LG + 1); it shadows the global N inside Poly.
const int N = 1 << LG + 1;
// Generator used by init() to derive the roots of unity (3 is a primitive root of 998244353).
const int G = 3;
// lg2[i]: floor(log2(i)), filled by init().
int lg2[N + 5];
// Factorials, inverse factorials and modular inverses, filled by init().
int fac[N + 5], ifac[N + 5], inv[N + 5];
// Twiddle table, filled by init(); the butterflies walk it sequentially from rt[0], one entry per block.
int rt[N + 5];
// Fills lg2[], the twiddle table rt[], and the factorial / inverse tables. Must run before any transform.
inline void init() {
for (int i = 2; i <= N; ++i)
lg2[i] = lg2[i >> 1] + 1;
// Seed the power-of-two slots: rt[1 << LG] is G^((mod - 1) >> (LG + 2)) (the shift parses as
// LG + 2), and each lower slot is the square of the one above it.
rt[0] = 1, rt[1 << LG] = qpow(G, (mod - 1) >> LG + 2);
for (int i = LG; i; --i)
rt[1 << i - 1] = (ll)rt[1 << i] * rt[1 << i] % mod;
// Every other entry is the product of the seeds of its set bits:
// rt[i] = rt[i without its lowest bit] * rt[lowest bit of i].
for (int i = 1; i < N; ++i)
rt[i] = (ll)rt[i & i - 1] * rt[i & -i] % mod;
fac[0] = 1;
for (int i = 1; i <= N; ++i)
fac[i] = (ll)fac[i - 1] * i % mod;
// One modular inversion for N!, then walk downwards: 1/(i-1)! = (1/i!) * i.
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N; i; --i)
ifac[i - 1] = (ll)ifac[i] * i % mod;
// 1/i = (i-1)! / i!.
for (int i = 1; i <= N; ++i)
inv[i] = (ll)ifac[i] * fac[i - 1] % mod;
}
// Dense polynomial / truncated power series modulo `mod`; a[i] is the
// coefficient of x^i. Transform-based routines require Poly::init() to have run.
struct poly {
    vector<int> a;

    // A non-zero x becomes the constant polynomial x; zero becomes the empty polynomial.
    inline poly(int x = 0) {
        if (x)
            a.push_back(x);
    }
    inline poly(const vector<int> &o) {
        a = o;
    }
    inline poly(const poly &o) {
        a = o.a;
    }

    inline int size() const {
        return a.size();
    }
    inline bool empty() const {
        return a.empty();
    }
    inline void resize(int x) {
        a.resize(x);
    }

    // Bounds-checked read: coefficients outside [0, size()) read as 0.
    inline int operator[](int x) const {
        if (x < 0 || x >= size())
            return 0;
        return a[x];
    }

    // Releases the storage as well (swap with an empty vector).
    inline void clear() {
        vector<int>().swap(a);
    }

    // First min(n, size()) coefficients, i.e. the polynomial modulo x^n.
    inline poly modxn(int n) const {
        if (a.empty())
            return poly();
        n = min(n, size());
        return poly(vector<int>(a.begin(), a.begin() + n));
    }

    // Coefficients in reverse order.
    inline poly rever() const {
        return poly(vector<int>(a.rbegin(), a.rend()));
    }

    // Forward transform, in place. One twiddle per block, read sequentially
    // from rt[]; size() must be a power of two.
    inline void dif() {
        int n = size();
        for (int i = 0, len = n >> 1; len; ++i, len >>= 1)
            for (int j = 0, *w = rt; j < n; j += len << 1, ++w)
                for (int k = j; k < j + len; ++k) {
                    int R = (ll)*w * a[k + len] % mod;
                    a[k + len] = reduce(a[k] - R),
                    add(a[k], R);
                }
    }

    // Inverse transform, in place: mirrored butterflies, then reversing
    // a[1..n-1] and scaling by 1/n.
    inline void dit() {
        int n = size();
        for (int i = 0, len = 1; len < n; ++i, len <<= 1)
            for (int j = 0, *w = rt; j < n; j += len << 1, ++w)
                for (int k = j; k < j + len; ++k) {
                    int R = norm(a[k] + a[k + len]);
                    a[k + len] = (ll)*w * (a[k] - a[k + len] + mod) % mod,
                    a[k] = R;
                }
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; ++i)
            a[i] = (ll)a[i] * inv[n] % mod;
    }

    // type == 1: forward transform; any other value: inverse transform.
    inline void ntt(int type = 1) {
        type == 1 ? dif() : dit();
    }

    // Coefficient-wise sum; the result is as long as the longer operand.
    friend inline poly operator+(const poly &a, const poly &b) {
        vector<int> ret(max(a.size(), b.size()));
        for (int i = 0; i < ret.size(); ++i)
            ret[i] = norm(a[i] + b[i]);
        return poly(ret);
    }

    // Coefficient-wise difference; the result is as long as the longer operand.
    friend inline poly operator-(const poly &a, const poly &b) {
        vector<int> ret(max(a.size(), b.size()));
        for (int i = 0; i < ret.size(); ++i)
            ret[i] = reduce(a[i] - b[i]);
        return poly(ret);
    }

    // Product. Schoolbook multiplication when either operand has fewer than
    // 40 coefficients, otherwise via the transform.
    friend inline poly operator*(poly a, poly b) {
        if (a.empty() || b.empty())
            return poly();
        if (a.size() < 40 || b.size() < 40) {
            if (a.size() > b.size())
                swap(a, b);
            poly ret;
            ret.resize(a.size() + b.size() - 1);
            for (int i = 0; i < ret.size(); ++i)
                for (int j = 0; j <= i && j < a.size(); ++j)
                    ret.a[i] = (ret[i] + (ll)a[j] * b[i - j]) % mod;
            return ret;
        }
        int lim = 1, tot = a.size() + b.size() - 1;
        for (; lim < tot; lim <<= 1);
        a.resize(lim), b.resize(lim);
        a.ntt(), b.ntt();
        for (int i = 0; i < lim; ++i)
            a.a[i] = (ll)a[i] * b[i] % mod;
        a.ntt(-1), a.resize(tot);
        return a;
    }

    poly &operator+=(const poly &o) {
        resize(max(size(), o.size()));
        for (int i = 0; i < o.size(); ++i)
            add(a[i], o[i]);
        return *this;
    }
    poly &operator-=(const poly &o) {
        resize(max(size(), o.size()));
        for (int i = 0; i < o.size(); ++i)
            sub(a[i], o[i]);
        return *this;
    }
    poly &operator*=(poly o) {
        return (*this) = (*this) * o;
    }

    // Formal derivative.
    poly deriv() const {
        if (empty())
            return poly();
        vector<int> ret(size() - 1);
        for (int i = 0; i < size() - 1; ++i)
            ret[i] = (ll)(i + 1) * a[i + 1] % mod;
        return poly(ret);
    }

    // Formal antiderivative with constant term 0.
    poly integ() const {
        if (empty())
            return poly();
        vector<int> ret(size() + 1);
        for (int i = 0; i < size(); ++i)
            ret[i + 1] = (ll)a[i] * inv[i + 1] % mod;
        return poly(ret);
    }

    // Multiplicative inverse modulo x^m, doubling the precision each round.
    // Reads a[0] directly, so the polynomial must be non-empty with an
    // invertible constant term.
    inline poly inver(int m) const {
        poly ret(qpow(a[0], mod - 2)), f, g;
        for (int k = 1; k < m;) {
            k <<= 1, f.resize(k), g.resize(k);
            for (int i = 0; i < k; ++i)
                f.a[i] = operator[](i), g.a[i] = ret[i];
            f.ntt(), g.ntt();
            for (int i = 0; i < k; ++i)
                f.a[i] = (ll)f[i] * g[i] % mod;
            f.ntt(-1);
            // Only the upper half of f * g is needed for the correction term.
            for (int i = 0; i < (k >> 1); ++i)
                f.a[i] = 0;
            f.ntt();
            for (int i = 0; i < k; ++i)
                f.a[i] = (ll)f[i] * g[i] % mod;
            f.ntt(-1);
            ret.resize(k);
            for (int i = (k >> 1); i < k; ++i)
                ret.a[i] = neg(f[i]);
        }
        return ret.modxn(m);
    }

    // Division with remainder: returns {quotient, remainder} of *this by o,
    // using the reversed-coefficient trick for the quotient.
    inline pair<poly, poly> div(poly o) const {
        if (size() < o.size())
            return make_pair(poly(), *this);
        poly f, g;
        f = (rever().modxn(size() - o.size() + 1) * o.rever().inver(size() - o.size() + 1))
                .modxn(size() - o.size() + 1).rever();
        g = (modxn(o.size() - 1) - o.modxn(o.size() - 1) * f.modxn(o.size() - 1)).modxn(o.size() - 1);
        return make_pair(f, g);
    }

    // Logarithm modulo x^m, as the integral of f' / f
    // (presumably requires constant term 1 — confirm at call sites).
    inline poly log(int m) const {
        return (deriv() * inver(m)).integ().modxn(m);
    }

    // Exponential modulo x^m (the result starts from constant term 1, so the
    // input presumably has constant term 0 — confirm at call sites).
    // For m < 70 a quadratic recurrence is used, otherwise a doubling scheme.
    inline poly exp(int m) const {
        poly ret(1), iv, it, d = deriv(), itd, itd0, t1;
        if (m < 70) {
            ret.resize(m);
            for (int i = 1; i < m; ++i) {
                for (int j = 1; j <= i; ++j)
                    ret.a[i] = (ret[i] + (ll)j * operator[](j) % mod * ret[i - j]) % mod;
                ret.a[i] = (ll)ret[i] * inv[i] % mod;
            }
            return ret;
        }
        for (int k = 1; k < m;) {
            k <<= 1;
            it.resize(k >> 1);
            for (int i = 0; i < (k >> 1); ++i)
                it.a[i] = ret[i];
            itd = it.deriv(), itd.resize(k >> 1);
            iv = ret.inver(k >> 1), iv.resize(k >> 1);
            it.ntt(), itd.ntt(), iv.ntt();
            for (int i = 0; i < (k >> 1); ++i)
                it.a[i] = (ll)it[i] * iv[i] % mod,
                itd.a[i] = (ll)itd[i] * iv[i] % mod;
            it.ntt(-1), itd.ntt(-1), sub(it.a[0], 1);
            for (int i = 0; i < k - 1; ++i)
                sub(itd.a[i % (k >> 1)], d[i]);
            itd0.resize((k >> 1) - 1);
            for (int i = 0; i < (k >> 1) - 1; ++i)
                itd0.a[i] = d[i];
            itd0 = (itd0 * it).modxn((k >> 1) - 1);
            t1.resize(k - 1);
            for (int i = (k >> 1) - 1; i < k - 1; ++i)
                t1.a[i] = itd[(i + (k >> 1)) % (k >> 1)];
            for (int i = k >> 1; i < k - 1; ++i)
                sub(t1.a[i], itd0[i - (k >> 1)]);
            t1 = t1.integ();
            for (int i = 0; i < (k >> 1); ++i)
                t1.a[i] = t1[i + (k >> 1)];
            for (int i = (k >> 1); i < k; ++i)
                t1.a[i] = 0;
            t1.resize(k >> 1), t1 = (t1 * ret).modxn(k >> 1), t1.resize(k);
            for (int i = (k >> 1); i < k; ++i)
                t1.a[i] = t1[i - (k >> 1)];
            for (int i = 0; i < (k >> 1); ++i)
                t1.a[i] = 0;
            ret -= t1;
        }
        return ret.modxn(m);
    }

    // Square root modulo x^m (the result starts from constant term 1, so the
    // input presumably has constant term 1 — confirm at call sites).
    inline poly sqrt(int m) const {
        poly ret(1), f, g;
        for (int k = 1; k < m;) {
            k <<= 1;
            f = ret, f.resize(k >> 1);
            f.ntt();
            for (int i = 0; i < (k >> 1); ++i)
                f.a[i] = (ll)f[i] * f[i] % mod;
            f.ntt(-1);
            for (int i = 0; i < k; ++i)
                sub(f.a[i % (k >> 1)], operator[](i));
            g = (2 * ret).inver(k >> 1), f = (f * g).modxn(k >> 1), f.resize(k);
            for (int i = (k >> 1); i < k; ++i)
                f.a[i] = f[i - (k >> 1)];
            for (int i = 0; i < (k >> 1); ++i)
                f.a[i] = 0;
            ret -= f;
        }
        return ret.modxn(m);
    }

    // k-th power modulo x^m via exp(k * log f).
    //   k1 - the exponent as used on power series (shift amount and multiplier of the logarithm)
    //   k2 - the exponent applied to the leading non-zero coefficient; defaults to k1
    // Returns the empty polynomial when the result is 0 modulo x^m.
    inline poly pow(int m, int k1, int k2 = -1) const {
        if (empty())
            return poly();
        if (k2 == -1)
            k2 = k1;
        // t: index of the first non-zero coefficient.
        int t = 0;
        for (; t < size() && !a[t]; ++t);
        // An all-zero polynomial leaves t == size(); the a[t] read below would
        // then be out of bounds, so treat it like the empty polynomial.
        if (t == size() || (ll)t * k1 >= m)
            return poly();
        poly ret;
        ret.resize(m);
        // Normalise the leading coefficient to 1 and remember its k2-th power.
        int u = qpow(a[t], mod - 2), v = qpow(a[t], k2);
        for (int i = 0; i < m - t * k1; ++i)
            ret.a[i] = (ll)operator[](i + t) * u % mod;
        ret = ret.log(m - t * k1);
        for (int i = 0; i < ret.size(); ++i)
            ret.a[i] = (ll)ret[i] * k1 % mod;
        ret = ret.exp(m - t * k1), t *= k1, ret.resize(m);
        // Shift back by t * k1 places and restore the leading coefficient.
        for (int i = m - 1; i >= t; --i)
            ret.a[i] = (ll)ret[i - t] * v % mod;
        for (int i = 0; i < t; ++i)
            ret.a[i] = 0;
        return ret;
    }
};
}
using Poly::fac;
using Poly::ifac;
using Poly::inv;
using Poly::init;
using Poly::poly;
inline int binom(int n, int m) {
return n < m || m < 0 ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
// Chirp-Z transform: returns ret with ret[k] = sum_i f[i] * c^{i*k} for
// 0 <= k < m, using i*k = C(i+k,2) - C(i,2) - C(k,2) to turn the sum into a
// single polynomial product.
//   f : input coefficients.
//   c : base of the geometric evaluation points (its inverse is taken mod `mod`).
//   m : number of output values.
// Returns an empty poly when m <= 0 and an all-zero poly of size m when f is
// empty (the unguarded code would size a vector with a non-positive length
// and then write cpow[0]).
inline poly czt(const poly &f, int c, int m) {
    int n = f.size(), ci = qpow(c, mod - 2);
    poly a, b, ret;
    if (m <= 0)
        return ret;
    ret.resize(m);
    if (n == 0)
        return ret;
    a.resize(n), b.resize(n + m - 1);
    // cpow[i] = c^{C(i,2)}, cipow[i] = c^{-C(i,2)}.
    vector<int> cpow(n + m - 1), cipow(max(n, m));
    cpow[0] = 1;
    for (int i = 1, pw = 1; i < cpow.size(); ++i)
        cpow[i] = (ll)cpow[i - 1] * pw % mod,
        pw = (ll)pw * c % mod;
    cipow[0] = 1;
    for (int i = 1, pw = 1; i < cipow.size(); ++i)
        cipow[i] = (ll)cipow[i - 1] * pw % mod,
        pw = (ll)pw * ci % mod;
    for (int i = 0; i < n; ++i)
        a.a[i] = (ll)f[i] * cipow[i] % mod;
    // b holds cpow reversed so that the product's coefficient n+m-2-k
    // collects sum_i a[i] * cpow[i+k].
    for (int i = 0; i < n + m - 1; ++i)
        b.a[n + m - 2 - i] = cpow[i];
    a *= b;
    for (int i = 0; i < m; ++i)
        ret.a[i] = (ll)cipow[i] * a[n + m - 2 - i] % mod;
    return ret;
}
namespace QAnalog {
const int q = 2;
int qPow[N + 5];
int n[N + 5], fac[N + 5], ifac[N + 5];
inline void init() {
qPow[0] = 1;
for (int i = 1; i <= N; ++i)
qPow[i] = (ll)qPow[i - 1] * q % mod;
int qi = qpow(reduce(1 - q), mod - 2);
fac[0] = 1;
for (int i = 1; i <= N; ++i)
n[i] = (ll)(1 - qPow[i] + mod) * qi % mod,
fac[i] = (ll)fac[i - 1] * n[i] % mod;
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N; i; --i)
ifac[i - 1] = (ll)ifac[i] * n[i] % mod;
}
inline int binom(int n, int m) {
return n < m || m < 0 ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
}
// Problem state filled in by main():
// n, H are the two counts read first; L[i], W[i] are the n pairs read later;
// R is the maximum of r[1..H]; lim is the maximum over all L[i] and W[i].
int n, H, L[N + 5], W[N + 5], R, lim;
// r[1..H]: the H input values; coe: factor multiplied into every printed answer.
int r[N + 5], coe = 1;
// f: sequence passed to TellegensPrinciple::solve; g: its results copied back;
// ans: the value printed for the current pair.
int f[N + 5], g[N + 5], ans;
// Product-tree evaluation using transposed multiplication.
// seg[p] is the product of the linear factors (x - (l-1)) / l over the
// node's index range; ans[l] receives fac[l] * sum_i seg_leaf[i] * f_l[i],
// where f_l is the vector that reaches leaf l.
namespace TellegensPrinciple {
poly seg[N * 4 + 5];
int ans[N + 5];
// Transposed product: returns ret of length k with
//   ret[i] = sum_j a[i + j] * b[j].
// k defaults to a.size() - b.size() + 1; an empty poly is returned when either
// operand is empty or k <= 0.
inline poly mulT(poly a, poly b, int k = -1) {
    if (a.empty() || b.empty())
        return poly();
    int n = a.size(), m = b.size();
    if (k == -1)
        k = n - m + 1;
    if (k <= 0)
        return poly();
    // Short operands: evaluate the definition directly.
    if (k < 40 || b.size() < 40) {
        poly ret;
        ret.resize(k);
        for (int i = 0;i < k;++i)
            for (int j = 0;j < b.size();++j)
                fam(ret.a[i], a[i + j], b[j]);
        return ret;
    }
    // Otherwise multiply by reversed b and read coefficients m-1 .. m-1+k-1.
    reverse(b.a.begin(), b.a.end());
    if (k == n - m + 1) {
        // The wanted coefficients m-1 .. n-1 are not touched by wrap-around
        // in a transform of length lim >= n, so a single product suffices.
        int lim = 1;
        for (; lim < n; lim <<= 1);
        a.resize(lim), b.resize(lim);
        a.ntt(), b.ntt();
        for (int i = 0; i < lim; ++i)
            a.a[i] = (ll)a[i] * b[i] % mod;
        a.ntt(-1);
        for (int i = 0; i < k; ++i)
            a.a[i] = a[m - 1 + i];
        for (int i = k; i < lim; ++i)
            a.a[i] = 0;
        a.resize(n - m + 1);
    } else {
        a *= b, a.resize(n + m - 1);
        for (int i = 0; i < k; ++i)
            a.a[i] = a[m - 1 + i];
        for (int i = k; i < n + m - 1; ++i)
            a.a[i] = 0;
        a.resize(k);
    }
    return a;
}
// Builds the product tree over indices [l, r]; requires l <= r.
void build(int p, int l, int r) {
    if (l == r) {
        seg[p] = poly({(ll)neg(l - 1) * inv[l] % mod, inv[l]});
        return ;
    }
    int mid = l + r >> 1;
    build(ls, l, mid), build(rs, mid + 1, r);
    seg[p] = seg[ls] * seg[rs];
}
// Pushes f down the tree: the left child sees f truncated to its own degree
// plus one, the right child sees mulT(f, seg[left]).
void solve(int p, int l, int r, poly f) {
    if (l == r) {
        for (int i = 0; i < f.size(); ++i)
            fam(ans[l], seg[p][i], f[i]);
        ans[l] = (ll)ans[l] * fac[l] % mod;
        return ;
    }
    int mid = l + r >> 1;
    solve(ls, l, mid, f.modxn(mid - l + 2)), solve(rs, mid + 1, r, mulT(f, seg[ls]));
}
// Entry point: fills ans[1 .. f.size()-1].
void solve(poly f) {
    int n = f.size() - 1;
    // With fewer than two coefficients there is no index to evaluate, and
    // build(1, 1, n) would recurse forever on the empty range [1, n].
    if (n < 1)
        return ;
    build(1, 1, n), solve(1, 1, n, f);
}
}
// Entry point: reads n, H, the H values r[i] and the n pairs (L[i], W[i]);
// prints one answer per pair.
int main() {
init(), QAnalog::init();
scanf("%d%d", &n, &H);
// R = max r[i]; coe collects, for every r[i], 2^{r(r-1)/2} times the
// q-factorial of r[i].
for (int i = 1; i <= H; ++i)
scanf("%d", r + i), R = max(R, r[i]),
coe = (ll)coe * qpow(2, (ll)r[i] * (r[i] - 1) / 2 % (mod - 1)) % mod * QAnalog::fac[r[i]] % mod;
// Base value f[R]: product of q-binomials (R choose r[i]).
f[R] = 1;
for (int i = 1; i <= H; ++i)
f[R] = (ll)f[R] * QAnalog::binom(R, r[i]) % mod;
for (int i = 1; i <= n; ++i)
scanf("%d%d", L + i, W + i), lim = max({lim, L[i], W[i]});
// temp starts as the histogram of the r[i] values, is evaluated at powers of
// inv[2] (H+1 points), scaled by -1/i, exponentiated, and finally evaluated
// at powers of 2 (lim+1 points).
poly temp;
temp.resize(R + 1);
for (int i = 1; i <= H; ++i)
add(temp.a[r[i]], 1);
temp = czt(temp, inv[2], H + 1), temp.a[0] = 0;
for (int i = 1; i <= H; ++i)
temp.a[i] = (ll)(mod - inv[i]) * temp[i] % mod;
temp = temp.exp(H + 1);
temp = czt(temp, 2, lim + 1);
// Extend f from R+1 up to lim with a ratio built from (1 - 2^i)^H and temp[i].
for (int i = R + 1; i <= lim; ++i)
f[i] = (ll)f[i - 1] * qpow(reduce(1 - QAnalog::qPow[i]), H) % mod * qpow(temp[i], mod - 2) % mod;
// g[1..lim] are the values produced by TellegensPrinciple::solve for f[0..lim].
TellegensPrinciple::solve(poly(vector<int>(f, f + lim + 1)));
copy(TellegensPrinciple::ans + 1, TellegensPrinciple::ans + lim + 1, g + 1);
// Each answer is coe * g[L[i]] * g[W[i]] modulo mod.
for (int i = 1; i <= n; ++i) {
ans = (ll)coe * g[L[i]] % mod * g[W[i]] % mod;
printf("%d\n", ans);
}
}
考虑 \(\left\lfloor\sqrt{\frac nk}\right\rfloor\) 有多少种不同的取值。
容易发现以 \(n^{1/3}\) 为分界线可知有 \(O(n^{1/3})\) 种。
接下来处理计算 \(\mu^2\) 的前缀和。易知我们需要所有 \(\left\lfloor\frac n{x^2}\right\rfloor\) 处的 \(\mu^2\) 的前缀和。
注意到 \[\sum_{i=1}^n \mu^2(i) = \sum_{i=1}^{\lfloor\sqrt n\rfloor} \mu(i) \left\lfloor\frac n{i^2}\right\rfloor\]
先不管怎么求 \(\mu\)。考虑设阈值 \(B\):对 \(\left\lfloor\frac n{x^2}\right\rfloor \le B\),线性筛;对 \(\left\lfloor\frac n{x^2}\right\rfloor > B\),使用 \(O\left(\left\lfloor\frac n{x^2}\right\rfloor^{1/3}\right)\) 的整除分块计算。
后面的复杂度大约是 \[\begin{aligned}&\quad\; \int_0^{ \sqrt{\frac nB} } \left(\frac n{x^2}\right)^{1/3} \,\mathrm d x \\&= \left.3 n^{1/3} x^{1/3}\right|_{x=0}^{ \sqrt{\frac nB} } \\&= O(n^{1/2} B^{-1/6})\end{aligned}\]
平衡一下,取 \(B = n^{3/7}\) 可得复杂度为 \(O(n^{3/7})\)。
然后考虑求 \(\mu\)。我们需要的位置是 \(\left\lfloor\sqrt{ \frac n{x^2y} }\right\rfloor = \left\lfloor\frac{\left\lfloor\sqrt{\frac ny}\right\rfloor}x\right\rfloor\)。
设阈值 \(T\),\(\le T\) 的部分用线性筛,\(> T\) 的部分用杜教筛,那么后面部分的复杂度大约是 \[\begin{aligned}&\quad\; \int_0^{ \frac n{T^2} } \left(\frac ny\right)^{1/3} \mathrm dy \\&= \left.\frac 32 n^{1/3} y^{2/3} \right|_{y=0}^{ \frac n{T^2} } \\&= O(n T^{-4/3})\end{aligned}\]
平衡一下,取 \(T = n^{3/7}\) 可得复杂度为 \(O(n^{3/7})\)。
然后乱写一下就过了,虽然常数贼大。
代码:
using ll = long long;
using namespace std;
// Capacity of the sieve tables below.
const int LIM = 9e6;
// The two input bounds; main() swaps them so that n <= m.
ll n, m;
// Sieve limit actually used (set in main from m).
int lim;
// Linear-sieve scratch: composite marks, prime count and prime list.
int vis[LIM + 5], cnt, prime[LIM + 5];
// After the sieve in main(): mu[i] is the prefix sum of the Moebius function
// and mu2[i] the number of i' <= i with non-zero Moebius value.
int mu[LIM + 5], mu2[LIM + 5];
// Memo tables for sumMu / sumMu2 on arguments above lim.
unordered_map<ll, ll> memMu, memMu2;
// Prefix sum of the Moebius function up to n: table lookup for n <= lim,
// otherwise memoised recursion over the distinct quotients n / l.
ll sumMu(ll n) {
    if (n <= lim)
        return mu[n];
    auto hit = memMu.find(n);
    if (hit != memMu.end())
        return hit->second;
    ll acc = 1;
    ll lo = 2;
    while (lo <= n) {
        const ll quot = n / lo;
        const ll hi = n / quot;  // last l sharing the same quotient
        acc -= (hi - lo + 1) * sumMu(quot);
        lo = hi + 1;
    }
    memMu[n] = acc;
    return acc;
}
// Number of square-free integers in [1, n], i.e. the prefix sum of mu^2,
// via  sum_{i <= sqrt(n)} mu(i) * floor(n / i^2)  grouped by equal quotients.
// Uses the sieved table for n <= lim and memoises larger arguments.
ll sumMu2(ll n) {
    if (n <= lim)
        return mu2[n];
    if (memMu2.count(n))
        return memMu2[n];
    // Exact integer square root: the double result of sqrt() can be off by one
    // once the argument no longer fits the 53-bit mantissa, so correct it.
    auto isqrt = [](ll x) -> ll {
        ll s = (ll)sqrt((double)x);
        while (s > 0 && s * s > x)
            --s;
        while ((s + 1) * (s + 1) <= x)
            ++s;
        return s;
    };
    ll ret = 0;
    int nSqrt = isqrt(n);
    for (int l = 1, r; l <= nSqrt; l = r + 1) {
        ll v = n / ((ll)l * l);
        // r is the largest index with the same quotient floor(n / r^2) == v.
        r = isqrt(n / v);
        ret += (sumMu(r) - sumMu(l - 1)) * v;
    }
    return memMu2[n] = ret;
}
ll ans;
int main() {
scanf("%lld%lld", &n, &m);
if (n > m)
swap(n, m);
lim = pow(m, 3.0 / 7);
mu[1] = mu2[1] = 1;
for (int i = 2; i <= lim; ++i) {
if (!vis[i])
prime[++cnt] = i, mu[i] = -1;
for (int j = 1; j <= cnt && i * prime[j] <= lim; ++j) {
vis[i * prime[j]] = 1;
if (!(i % prime[j]))
break;
mu[i * prime[j]] = -mu[i];
}
mu2[i] = mu2[i - 1] + (bool)mu[i],
mu[i] += mu[i - 1];
}
for (ll l = 1, r; l <= n; l = r + 1) {
int v0 = sqrt(n / l), v1 = sqrt(m / l);
r = min(n / ((ll)v0 * v0), m / ((ll)v1 * v1));
ans += (sumMu2(r) - sumMu2(l - 1)) * v0 * v1;
}
printf("%lld\n", ans);
}
考虑按以下方式行动必然不劣:
对于路径上每个点,以其为根,首先穿到不在路径上的点的子树中移动(并且这些边最多往返各经过一次),然后回到路径上走到下一个点。
考虑求出每个点为根时在各个子树内移动的最小代价,查询时减掉路径上的相邻点的贡献即可。
记其为 \(f_u\)。
不妨先计算 \(g_u\) 为任选一根时不考虑 \(u\) 的当前根向子树的答案。那么有显然转移 \[\newcommand\fa{ \operatorname{fa} }\newcommand\lca{ \operatorname{lca} }\newcommand\child{ \operatorname{child} }\newcommand\path{ \operatorname{path} }g_u = a_u + \sum_{v\in\child(u)} \max\{0,g_v - z(u,v) - z(v,u)\}\]
对于 \(u\) 和 \(v\in\child(u)\),考虑按换根套路计算 \(f_v\),无非是 \(g_v+\max\{0, f_u-\max\{0,g_v-z(u,v)-z(v,u)\}-z(u,v)-z(v,u)\}\)。
对于询问 \((x,y)\),先计算路径上的边权和,然后剩下的部分就是 \[\begin{aligned}&\quad\;f_{\lca(x,y)} + \sum_{ u\in\path(x,y)\setminus\{\lca(x,y)\} } (g_u-\max\{0,g_u-z(u,\fa(u))-z(\fa(u),u)\}) \\&=f_{\lca(x,y)} + \sum_{ u\in\path(x,y)\setminus\{\lca(x,y)\} } \min\{g_u,z(u,\fa(u))+z(\fa(u),u)\}\end{aligned}\]
代码:
using ll = long long;
using namespace std;
const int N = 2e5;
// Highest binary-lifting level stored in faBinary.
const int LG = 17;
int n, q;
// Adjacency lists of (neighbour, weight stored on that directed entry).
vector<pair<int, int>> e[N + 5];
// a: per-vertex value; w[0][v]: weight on the parent->v entry,
// w[1][v]: weight on the v->parent entry (both set in dfs1).
int a[N + 5], w[2][N + 5];
// Parent, depth and 2^i-th ancestors of every vertex.
int fa[N + 5], dep[N + 5], faBinary[LG + 5][N + 5];
// g: subtree DP from dfs1, f: rerooted DP from dfs2,
// wSum[t][u]: sum of w[t] along the root-to-u path.
ll f[N + 5], g[N + 5], wSum[2][N + 5];
// sum[u]: root-to-u prefix sum of min(g, w[0] + w[1]).
ll sum[N + 5];
// First pass from the root: fills parents, depths, binary-lifting table,
// both edge weights per vertex, wSum[0] and the subtree value g.
void dfs1(int u) {
    g[u] = a[u];
    wSum[0][u] = wSum[0][fa[u]] + w[0][u];
    faBinary[0][u] = fa[u];
    for (int lvl = 0; lvl < LG; ++lvl)
        faBinary[lvl + 1][u] = faBinary[lvl][faBinary[lvl][u]];
    for (const auto &edge : e[u]) {
        const int v = edge.first, z = edge.second;
        if (v == fa[u]) {
            // Entry pointing back to the parent carries the upward weight.
            w[1][u] = z;
            continue;
        }
        fa[v] = u;
        dep[v] = dep[u] + 1;
        w[0][v] = z;
        dfs1(v);
        // A child contributes only if it outweighs both directions of its edge.
        g[u] += max(0LL, g[v] - w[0][v] - w[1][v]);
    }
}
void dfs2(int u) {
wSum[1][u] = wSum[1][fa[u]] + w[1][u];
sum[u] = sum[fa[u]] + min(g[u], (ll)w[0][u] + w[1][u]);
for (auto one: e[u]) {
int v = one.first;
if (v != fa[u])
f[v] = g[v] + max(0LL, f[u] - max(0LL, g[v] - w[0][v] - w[1][v]) - w[0][v] - w[1][v]), dfs2(v);
}
}
inline int lca(int x, int y) {
if (dep[x] < dep[y])
swap(x, y);
for (int i = LG; i >= 0; --i)
if (faBinary[i][x] && dep[faBinary[i][x]] >= dep[y])
x = faBinary[i][x];
if (x == y)
return x;
for (int i = LG; i >= 0; --i)
if (faBinary[i][x] != faBinary[i][y])
x = faBinary[i][x], y = faBinary[i][y];
return fa[x];
}
ll ans;
int main() {
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
for (int i = 2; i <= n; ++i) {
int u, v, w0, w1;
scanf("%d%d%d%d", &u, &v, &w0, &w1);
e[u].emplace_back(v, w0), e[v].emplace_back(u, w1);
}
dfs1(1), f[1] = g[1], dfs2(1);
for (int x, y; q; --q) {
scanf("%d%d", &x, &y);
int anc = lca(x, y);
ans = f[anc];
ans += sum[x] + sum[y] - 2 * sum[anc];
ans -= wSum[1][x] + wSum[0][y] - wSum[0][anc] - wSum[1][anc];
printf("%lld\n", ans);
}
}
这样的转移很阴间。不过这样的对「路径下方所挂的点」的求和,提示我们作树上差分。
设 \[f_u = f'_u - \sum_{v\in \child(u)} f'_v\]
这样,不难验证转移会变成 \[f_u = \max\left(\{0\} \cup \left\{\left.w - \sum_{ v \in \path(x,y) \setminus\{u\} } f_v\,\right|\,(x,y,w) \in S \land \lca(x,y)=u\right\}\right)\]
于是可以使用树状数组维护单点加链求和来做到 \(O((n+m)\log n)\)。
接下来考虑计算 \(f(x,y)\)。观察到,若树的根在 \(\path(x,y)\) 上,那么 \[f(x,y) = \sum_{u\in\path(x,y)} f_u\]
也就是说问题变成了换根 DP。
设 \(g_u\) 为根在 \(u\) 子树内时 \(\fa(u)\) 的 DP 值,\(h_u\) 为根为 \(u\) 时 \(u\) 的 DP 值。
计算 \(g_u\) 时,考虑一条经过 \(\fa(u)\) 而不经过 \(u\) 的路径 \((x,y,w)\),从 \(\fa(u)\) 出发向下的一段取 \(f\),\(\fa(u)\) 到 \(\lca(x,y)\) 的一段取 \(g\),\(\lca(x,y)\) 往下的另一段也取 \(f\)。
计算 \(h_u\) 时类似,不过考虑的是经过 \(u\) 的路径。
考虑继续在 \(u\) 处枚举 \(\lca(x,y) = u\) 的路径 \((x,y,w)\),我们希望对路径上的点都挂上一个贡献,同时 \(u\) 处要特殊处理,因为有儿子的要求。
进一步,对于 \(u\) 枚举 \(v \in \child(u)\),考虑端点在 \(u\) 子树内而不在 \(v\) 子树内的路径,这样的路径就是对 \(g_v\) 有贡献的。对于 \(h_u\) 则更简单。
于是我们不妨用线段树维护 DFS 序来计算这部分转移。
\(u\) 处的特殊贡献的话,也可以直接对所有儿子建立线段树(或者如果足够闲,甚至可以通过差分和堆计算),这样不转移 \(O(1)\) 个儿子可以转化为对 \(O(1)\) 个区间取 \(\max\),然后单点查询。
时间复杂度 \(O((n+m)\log n)\)。
代码:
using ll = long long;
using namespace std;
const int mod = 998244353;
const int N = 3e5;
const ll inf = 0x3f3f3f3f3f3f3f3f;
int n, m;
// path[i] = (u, v, w) exactly as read from the input.
tuple<int, int, int> path[N + 5];
// pathLCA[i] = (lca, xson, yson) as returned by lca(u, v).
tuple<int, int, int> pathLCA[N + 5];
// e: adjacency lists; ch[u]: children of u in the order dfs visits them.
vector<int> e[N + 5], ch[N + 5];
// label[v]: position of v inside ch[fa[v]].
int label[N + 5];
// fa/dep/size: parent, depth, subtree size; son: child with the largest
// subtree; top: head of the chain containing the vertex (set in main);
// id: DFS index, rk: its inverse.
int fa[N + 5], dep[N + 5], size[N + 5], son[N + 5], top[N + 5], id[N + 5], rk[N + 5];
// Roots the tree: assigns DFS indices, parents, depths, subtree sizes, the
// largest-subtree child and each child's position in its parent's list.
void dfs(int u) {
    static int tot = 0;
    id[u] = ++tot;
    rk[id[u]] = u;
    size[u] = 1;
    for (int v : e[u]) {
        if (v == fa[u])
            continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs(v);
        size[u] += size[v];
        if (son[u] == 0 || size[son[u]] < size[v])
            son[u] = v;
        label[v] = ch[u].size();
        ch[u].emplace_back(v);
    }
}
// Chain-jumping LCA. Returns (lca, xson, yson) where xson / yson is the child
// of the LCA on the side of x / y, or 0 when that endpoint is the LCA itself.
inline tuple<int, int, int> lca(int x, int y) {
    int xson = 0, yson = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            xson = top[x];
            x = fa[xson];
        } else {
            yson = top[y];
            y = fa[yson];
        }
    }
    // Same chain now: the deeper endpoint hangs below the other's heavy child.
    if (dep[x] < dep[y])
        yson = son[x];
    else if (dep[x] > dep[y])
        xson = son[y];
    const int anc = dep[x] < dep[y] ? x : y;
    return {anc, xson, yson};
}
// pathsThrough[u]: indices of the input paths whose LCA is u.
vector<int> pathsThrough[N + 5];
// Fenwick tree over positions 1..n (global n) with range add / point query.
struct BinaryIndexedTree {
    ll c[N + 5];
    inline int lowbit(int x) {
        return x & (-x);
    }
    // Adds k at position x of the underlying difference array.
    inline void update(int x, ll k) {
        while (x <= n) {
            c[x] += k;
            x += lowbit(x);
        }
    }
    // Adds k to every position in [l, r].
    inline void update(int l, int r, ll k) {
        update(l, k);
        update(r + 1, -k);
    }
    // Current value at position x.
    inline ll query(int x) {
        ll total = 0;
        while (x) {
            total += c[x];
            x -= lowbit(x);
        }
        return total;
    }
} bit;
// Shared node storage for the two segment-tree variants derived from it.
struct SegmentTree {
ll seg[N * 4 + 5];
};
// Segment tree over positions 0..n-1 supporting "range chmax, point query".
// Tags are never pushed down; a point query takes the max along its path.
struct SegmentTree_RangeUpdate: SegmentTree {
    int n;
    // Resets every node covering [tl, tr] to 0 (no-op on an empty range).
    void build(int p, int tl, int tr) {
        if (tl > tr)
            return ;
        seg[p] = 0;
        if (tl != tr) {
            int mid = tl + tr >> 1;
            build(ls, tl, mid);
            build(rs, mid + 1, tr);
        }
    }
    // Re-initialises the tree for m positions.
    void build(int m) {
        n = m;
        build(1, 0, n - 1);
    }
    // Raises every position in [l, r] to at least k.
    void update(int l, int r, ll k, int p, int tl, int tr) {
        if (l > r)
            return ;
        const bool covered = l <= tl && tr <= r;
        if (covered) {
            seg[p] = max(seg[p], k);
            return ;
        }
        int mid = tl + tr >> 1;
        if (l <= mid)
            update(l, r, k, ls, tl, mid);
        if (r > mid)
            update(l, r, k, rs, mid + 1, tr);
    }
    void update(int l, int r, ll k) {
        update(l, r, k, 1, 0, n - 1);
    }
    // Maximum tag on the root-to-leaf path of position x.
    ll query(int x, int p, int tl, int tr) {
        ll best = seg[p];
        while (tl != tr) {
            int mid = tl + tr >> 1;
            if (x <= mid) {
                p = ls;
                tr = mid;
            } else {
                p = rs;
                tl = mid + 1;
            }
            best = max(best, seg[p]);
        }
        return best;
    }
    ll query(int x) {
        return query(x, 1, 0, n - 1);
    }
} seg0;
// Segment tree over positions 1..n supporting "point chmax, range max".
struct SegmentTree_RangeQuery: SegmentTree {
    int n;
    // Resets every node covering [tl, tr] to -inf (no-op on an empty range).
    void build(int p, int tl, int tr) {
        if (tl > tr)
            return ;
        seg[p] = -inf;
        if (tl != tr) {
            int mid = tl + tr >> 1;
            build(ls, tl, mid);
            build(rs, mid + 1, tr);
        }
    }
    // Re-initialises the tree for m positions.
    void build(int m) {
        n = m;
        build(1, 1, n);
    }
    // Raises position x to at least k, updating every node on the way down.
    void insert(int x, ll k, int p, int tl, int tr) {
        for (;;) {
            seg[p] = max(seg[p], k);
            if (tl == tr)
                return ;
            int mid = tl + tr >> 1;
            if (x <= mid) {
                p = ls;
                tr = mid;
            } else {
                p = rs;
                tl = mid + 1;
            }
        }
    }
    void insert(int x, ll k) {
        insert(x, k, 1, 1, n);
    }
    // Maximum over [l, r]; -inf for an empty range.
    ll query(int l, int r, int p, int tl, int tr) {
        if (l > r)
            return -inf;
        if (l <= tl && tr <= r)
            return seg[p];
        int mid = tl + tr >> 1;
        const ll leftBest = l <= mid ? query(l, r, ls, tl, mid) : -inf;
        const ll rightBest = r > mid ? query(l, r, rs, mid + 1, tr) : -inf;
        return max(leftBest, rightBest);
    }
    ll query(int l, int r) {
        return query(l, r, 1, 1, n);
    }
} seg1;
// f: bottom-up DP value per vertex; fSum: root-to-vertex prefix sums of f.
ll f[N + 5], fSum[N + 5];
// g: top-down DP value per vertex; gSum: root-to-vertex prefix sums of g.
ll g[N + 5], gSum[N + 5];
// h: DP value of a vertex when it is itself the root (starts from f).
ll h[N + 5];
int ans;
// Reads the tree (n vertices) and m weighted paths, runs the bottom-up and
// top-down passes, and prints the accumulated answer modulo mod.
int main() {
scanf("%d%d", &n, &m);
for (int i = 2; i <= n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
e[u].emplace_back(v), e[v].emplace_back(u);
}
dfs(1);
// Chain heads, assigned in DFS order so that top[fa[u]] is already known.
for (int i = 1; i <= n; ++i) {
int u = rk[i];
top[u] = u == son[fa[u]] ? top[fa[u]] : u;
}
// Bucket every path under its LCA and remember the two children of the LCA.
for (int i = 1; i <= m; ++i) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
path[i] = {u, v, w}, pathsThrough[get<0>(pathLCA[i] = lca(u, v))].emplace_back(i);
}
// Bottom-up pass in reverse DFS order: f[u] is the best of 0 and
// w - (BIT value at x) - (BIT value at y) over paths with LCA u; afterwards
// f[u] is added to the BIT over u's whole DFS interval.
for (int i = n; i; --i) {
int u = rk[i];
for (int j: pathsThrough[u]) {
int x = get<0>(path[j]), y = get<1>(path[j]), w = get<2>(path[j]);
f[u] = max(f[u], w - bit.query(id[x]) - bit.query(id[y]));
}
bit.update(id[u], id[u] + size[u] - 1, f[u]);
h[u] = f[u];
}
for (int i = 1; i <= n; ++i) {
int u = rk[i];
fSum[u] = fSum[fa[u]] + f[u];
}
// Top-down pass in DFS order. seg1 (indexed by DFS id) holds candidates
// inserted at path endpoints by ancestors; seg0 (indexed by child label) is
// rebuilt per vertex for paths whose LCA is the vertex itself.
seg1.build(n);
for (int i = 1; i <= n; ++i) {
int u = rk[i];
gSum[u] = gSum[fa[u]] + g[u];
seg0.build(ch[u].size());
// Candidates with an endpoint inside u's subtree but outside v's subtree.
for (int v: ch[u])
g[v] = max(g[v], max(seg1.query(id[u], id[v] - 1), seg1.query(id[v] + size[v], id[u] + size[u] - 1)) + fSum[u] - gSum[u]);
h[u] = max(h[u], seg1.query(id[u], id[u] + size[u] - 1) + fSum[u] - gSum[u]);
for (int j: pathsThrough[u]) {
int x = get<0>(path[j]), y = get<1>(path[j]), w = get<2>(path[j]);
int xson = get<1>(pathLCA[j]), yson = get<2>(pathLCA[j]);
// Normalise so that x is the shallower endpoint.
if (dep[x] > dep[y])
swap(x, y), swap(xson, yson);
if (x == y)
// Single-vertex path: applies to every child.
seg0.update(0, ch[u].size() - 1, w);
else if(x == u) {
// Vertical path starting at u: applies to every child except yson.
ll v = w - fSum[y] + gSum[x];
seg1.insert(id[y], v);
v = w - fSum[y] + fSum[u];
seg0.update(0, label[yson] - 1, v), seg0.update(label[yson] + 1, ch[u].size() - 1, v);
} else {
// Path bending at u: applies to every child except xson and yson.
ll v = w - fSum[x] - fSum[y] + fSum[u] + gSum[u];
seg1.insert(id[x], v), seg1.insert(id[y], v);
v = w - fSum[x] - fSum[y] + 2 * fSum[u];
int id0 = label[xson], id1 = label[yson];
if (id0 > id1)
swap(id0, id1);
seg0.update(0, id0 - 1, v), seg0.update(id0 + 1, id1 - 1, v), seg0.update(id1 + 1, ch[u].size() - 1, v);
}
}
for (int v: ch[u])
g[v] = max(g[v], seg0.query(label[v]));
}
// Final sum: h[i] * n plus (f[i] + g[i]) * size[i] * (n - size[i]), mod `mod`.
for (int i = 1; i <= n; ++i)
ans = (ans + h[i] * n) % mod,
ans = (ans + (f[i] + g[i]) % mod * size[i] % mod * (n - size[i])) % mod;
printf("%d\n", ans);
}
对于参数 \(n,a,b,c\) 和 \(\texttt U,\texttt R\) 执行递归过程 \(f(n,a,b,c,\texttt U,\texttt R)\):表示在 \(y = \frac{ax+b}c\) \((x \in (0, n])\) 上考虑这个问题。
根据欧几里得过程,考虑讨论:
若 \(a \ge c\),考虑如何将 \(a\) 取模 \(c\)。
事实上这就是 \(f(n,a \bmod c,b,c,\texttt U,\texttt U^{\lfloor a/c\rfloor}\texttt R)\)。
若 \(a < c\),考虑如何交换 \(a,c\)。
记 \(m = \left\lfloor\frac{an+b}c\right\rfloor\) 即 \(\texttt U\) 的个数。
如果 \(m = 0\),以 \(\texttt R^n\) 终止递归。
否则,简单计算得,可以通过在开头补上 \(\texttt R^{\lfloor(c-b-1)/a\rfloor}\texttt U\),结尾补上 \(\texttt R^{n-\lfloor(cm-b-1)/a\rfloor}\) 的方式,递归到 \(f(m-1,c,(c-b-1) \bmod a,a,\texttt R,\texttt U)\)。
以下是我擅自从叉姐代码改出来的一个板子:
using ll = long long;
using namespace std;
// Universal Euclidean algorithm over an arbitrary monoid.
// MonoidT must provide static identity(), R(), U() and an associative
// operator*.
namespace UniversalEuclidean {
// a^b by binary exponentiation (b >= 0); returns identity() for b == 0.
template<class MonoidT>
MonoidT qpow(MonoidT a, ll b) {
    MonoidT ret = MonoidT::identity();
    for (; b; b >>= 1) {
        if (b & 1)
            ret = ret * a;
        a = a * a;
    }
    return ret;
}
// Builds the U/R word for the line y = (a*x + b) / c and returns its product.
// The word starts with floor(b/c) copies of U followed by one R, and then
// covers x = 1..n, where every unit step in x contributes an R preceded by
// one U per unit increase of floor((a*x + b) / c).
// Requires n, a, b >= 0 and c > 0.
template<class MonoidT>
MonoidT sum(ll n, ll a, ll b, ll c) {
    MonoidT r = MonoidT::R();
    MonoidT u = MonoidT::U();
    MonoidT prefix = qpow(u, b / c) * r;
    MonoidT suffix = MonoidT::identity();
    b %= c;
    while (true) {
        if (a >= c)
            // Fold the integer part of the slope into R.
            r = qpow(u, a / c) * r,
            a %= c;
        else {
            // a*n and c*m can exceed 64 bits, so form them in 128 bits.
            ll m = (ll)(((__int128)a * n + b) / c);
            if (!m)
                return prefix * qpow(r, n) * suffix;
            // Peel off the R's before the first U and after the last U,
            // then swap the roles of the axes.
            prefix = prefix * qpow(r, (c - b - 1) / a) * u,
            suffix = qpow(r, n - (ll)(((__int128)c * m - b - 1) / a)) * suffix;
            b = (c - b - 1) % a, n = m - 1;
            swap(a, c), swap(u, r);
        }
    }
}
}
这里是一个 LibreOJ 板子的 AC 代码:
using ll = long long;
using namespace std;
const int mod = 998244353;
// Maps a value from [0, 2*mod) into [0, mod).
inline int norm(int x) {
    if (x >= mod)
        return x - mod;
    return x;
}
// Maps a value from (-mod, mod) into [0, mod).
inline int reduce(int x) {
    if (x < 0)
        return x + mod;
    return x;
}
// Additive inverse of x modulo mod, keeping 0 as 0.
inline int neg(int x) {
    if (x == 0)
        return 0;
    return mod - x;
}
// In-place modular addition: x <- (x + y) mod `mod` for reduced operands.
inline void add(int &x, int y) {
    x += y - mod;
    if (x < 0)
        x += mod;
}
// In-place modular subtraction: x <- (x - y) mod `mod` for reduced operands.
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0)
        x += mod;
}
inline void fam(int &x, int y, int z) {
x = (x + (ll)y * z) % mod;
}
// a^b modulo mod by binary exponentiation (b >= 0).
inline int qpow(int a, int b) {
    int result = 1;
    while (b) {
        if (b & 1)
            result = (ll)result * a % mod;
        a = (ll)a * a % mod;
        b >>= 1;
    }
    return result;
}
// Fixed-size N x N matrix of residues modulo mod; default-constructed as zero.
struct Matrix {
    static const int N = 20;
    int a[N][N];
    inline Matrix() {
        for (auto &row : a)
            for (int &cell : row)
                cell = 0;
    }
    inline int *operator[](int x) {
        return a[x];
    }
    inline const int *operator[](int x) const {
        return a[x];
    }
    // The N x N identity matrix.
    inline static Matrix identity() {
        Matrix unit;
        for (int d = 0; d < N; ++d)
            unit[d][d] = 1;
        return unit;
    }
    // Entry-wise sum modulo mod.
    inline Matrix operator+(const Matrix &o) const {
        Matrix total;
        for (int row = 0; row < N; ++row)
            for (int col = 0; col < N; ++col)
                total[row][col] = norm(a[row][col] + o[row][col]);
        return total;
    }
    // Matrix product modulo mod.
    inline Matrix operator*(const Matrix &o) const {
        Matrix prod;
        for (int row = 0; row < N; ++row)
            for (int mid = 0; mid < N; ++mid)
                for (int col = 0; col < N; ++col)
                    fam(prod[row][col], a[row][mid], o[mid][col]);
        return prod;
    }
} A, B, ans;
struct Monoid {
Matrix AProd, BProd, sum;
inline Monoid(Matrix a = Matrix(), Matrix b = Matrix(), Matrix c = Matrix()): AProd(a), BProd(b), sum(c) {}
static Monoid identity() {
return Monoid(Matrix::identity(), Matrix::identity());
}
inline static Monoid R() {
return Monoid(A, Matrix::identity(), A);
}
inline static Monoid U() {
return Monoid(Matrix::identity(), B);
}
inline Monoid operator*(const Monoid &o) const {
return Monoid(AProd * o.AProd, BProd * o.BProd, sum + AProd * o.sum * BProd);
}
};
// Universal Euclidean algorithm over an arbitrary monoid providing
// identity(), R(), U() and operator*.
namespace UniversalEuclidean {
// a^b by binary exponentiation; identity() for b == 0.
template<class MonoidT>
MonoidT qpow(MonoidT a, ll b) {
    MonoidT acc = MonoidT::identity();
    while (b) {
        if (b & 1)
            acc = acc * a;
        a = a * a;
        b >>= 1;
    }
    return acc;
}
// Product of the U/R word of y = (a*x + b) / c for x in (0, n], preceded by
// floor(b/c) copies of U.
template<class MonoidT>
MonoidT sum(ll n, ll a, ll b, ll c) {
    MonoidT r = MonoidT::R();
    MonoidT u = MonoidT::U();
    MonoidT prefix = qpow(u, b / c);
    MonoidT suffix = MonoidT::identity();
    b %= c;
    for (;;) {
        if (a >= c) {
            // Fold the integer part of the slope into R.
            r = qpow(u, a / c) * r;
            a %= c;
            continue;
        }
        const ll m = ((__int128)a * n + b) / c;
        if (m == 0)
            return prefix * qpow(r, n) * suffix;
        // R's before the first U and after the last U, from the old a, b, c, n.
        const ll head = (c - b - 1) / a;
        const ll tail = n - ((__int128)c * m - b - 1) / a;
        prefix = prefix * qpow(r, head) * u;
        suffix = qpow(r, tail) * suffix;
        b = (c - b - 1) % a;
        n = m - 1;
        swap(a, c);
        swap(u, r);
    }
}
}
ll a, b, c, n;
int m;
// Reads a, c, b, n, m (in that input order) and the two m x m matrices A and
// B, then prints the m x m block of the accumulated sum.
int main() {
    scanf("%lld%lld%lld%lld%d", &a, &c, &b, &n, &m);
    for (int row = 0; row < m; ++row)
        for (int col = 0; col < m; ++col)
            scanf("%d", &A[row][col]);
    for (int row = 0; row < m; ++row)
        for (int col = 0; col < m; ++col)
            scanf("%d", &B[row][col]);
    ans = UniversalEuclidean::sum<Monoid>(n, a, b, c).sum;
    for (int row = 0; row < m; ++row) {
        for (int col = 0; col < m; ++col) {
            // Space between entries, newline after the last column.
            const char sep = " \n"[col == m - 1];
            printf("%d%c", ans[row][col], sep);
        }
    }
}
初值 \[F_0 = x\]
这些东西让我联想起了自己对此类复杂的含微分方程的生成函数递推的失败尝试的回忆,让我有些退缩,但是这题真的是这样做(
设二元生成函数 \(\mathcal F(x,t) = \sum\limits_{k\ge 0} F_k(x) t^k\),类似地可以写出 \[\frac{\partial}{\partial x} \mathcal F = 1 + t\mathcal F \cosh x\]
容易解得 \[\mathcal F(x) = \me^{t\sinh x} \int_0^x \me^{-t\sinh u} \d u\]
考察其 \([t^k]\),有 \[F_k(x) = \sum\limits_{j=0}^k \frac{ (\sinh x)^{k-j} }{(k-j)!} \int_0^x \frac{(-\sinh u)^j}{j!} \d u\]
注意到积分 \[\int_0^x \me^{ju} \d u =\begin{cases}\frac{\me^{jx}-1}j, &j\ne 0 \\x, &j=0\end{cases}\]
这告诉我们 \(F_k\) 必然是 \(\me^{cx},x\me^{cx}\) 的线性组合。并且显然有 \(-k \le c \le k\)。
通过一些没什么用的技巧可以获得 \(O(k^2)\) 的解法,但数据范围显然提示我们洞察整式递推。
设 \(F_k(x) = G_k(\sinh x)\),主要是为了简化积分形式,那么根据一开始的递推 \[\begin{aligned}F_k(x)&= \int_0^x G_{k-1}(\sinh u) \cosh u \d u \\&= \int_0^x G_{k-1}(\sinh u) \d(\sinh u) \qquad \left(\frac{\d(\sinh u)}{\d u} = \cosh u\right) \\&= \int_0^{\sinh x} G_{k-1}(v) \d v \\G_k(x)&= \int_0^x G_{k-1}(v) \d v\end{aligned}\]
考察 \(F_0(x)=x\),那么 \(G_0(x)\) 就应该等于反双曲函数 \(\sinh^{-1} x\)。由于其 ODE 形式不够简洁,不是那么容易推导递推式,因此我们考虑从其导数 \((\sinh^{-1} x)' = (1+x^2)^{-1/2}\) 入手。
令其为 \(A(x)\),那么立刻知 \[\begin{aligned}A'(x)(1+x^2) &= - xA(x) \\na_n &= (1-n) a_{n-2}\end{aligned}\]
设 \(B(x) = G_k(x)\),则 \(b_n = \frac{a_{n-k-1}}{ n^{ \underline{k+1} } }\) 即 \(A(x)\) 积分 \(k+1\) 次,代入递推式 \[\begin{aligned}(n-k-1)a_{n-k-1}&= (2-n+k) a_{n-k-3} \\(n-k-1)b_n n^{ \underline{k+1} }&= (2-n+k) b_{n-2} (n-2)^{ \underline{k+1} } \\n(n-1)(n-k-1) b_n&= -(n-k-1)(n-k-2)^2 b_{n-2} \\(n-k-1)[n(n-1)b_n + (n-k-2)^2 b_{n-2}]&= 0\end{aligned}\]
接下来讨论一下,发现 \(n=k+1\) 时 \(b_{n-2}=0,b_n=\frac1{(k+1)!}\),则 \[\begin{aligned}n(n-1)b_n + (n-k-2)^2 b_{n-2}&= \frac{[n=k+1]}{(k-1)!} \\n(n-1)b_n + [(n-2)(n-3)+(1-2k)(n-2)+k^2] b_{n-2}&= \frac{[n=k+1]}{(k-1)!} \\k^2 x^2 B + (1-2k)x^3 B' + x^2(1+x^2)B''&= \frac{ x^{k+1} }{(k-1)!} \\k^2 B + (1-2k)x B' + (1+x^2)B''&= \frac{ x^{k-1} }{(k-1)!}\end{aligned}\]
为了方便计算,我们作换元 \(x = \ln x\),那么需要提取的系数就变成了 \(x^c,x^c \ln x\)。
设 \(X = \frac{x-1/x}2\),令 \(P(x) = B(X)\),为了导出 \(P\) 的 ODE,我们事先作一些计算: \[\begin{aligned}P'(x)&= \left(\frac{ 1+x^{-2} }2\right) B'(X) \\P''(x)&= \left(\frac{ 1+x^{-2} }2\right)^2 B''(X) - x^{-3} B'(X) \\\Rightarrow B'(X)&= \left(\frac2{ 1+x^{-2} }\right) P'(x) \\B''(X)&= \left(\frac2{ 1+x^{-2} }\right)^2 P''(x) + x^{-3}\left(\frac2{ 1+x^{-2} }\right)^3 P'(x)\end{aligned}\]
强行代入再经过一番化简后可得 \[k^2 (1+x^2) P(x) + ((1+2k)x + (1-2k)x^3) P'(x) + ( x^2+x^4) P''(x) = \frac{X^{k-1}}{(k-1)!} \cdot (1+x^2)\]
令 \(g_n\) 为其 \([x^n \ln x]\),可得 \[(n+k)^2 g_n + (n-k-2)^2 g_{n-2} = 0\]
令 \(f_n\) 为其 \([x^n]\),可得 \[(n+k)^2 f_n + (n-k-2)^2 f_{n-2} + 2(n+k)g_n + 2(n-k-2)g_{n-2} = [x^n] \frac{X^{k-1}}{(k-1)!} (1+x^2)\]
考虑边界值,复习一下前面解出的微分方程可知 \[\begin{aligned}f_k &= \frac 1{2^k}\sum_{j=1}^k \frac{(-1)^j}{(k-j)!j!\cdot j} \\g_k & = \frac 1{2^kk!}\end{aligned}\]
如此从大至小递推即可。
代码:
using namespace std;
const int K = 1e7 + 2;
const int mod = 998244353;
const int inv2 = 499122177;
// a^b modulo mod by binary exponentiation (b >= 0).
inline int fpow(int a,int b)
{
    int result = 1;
    while(b)
    {
        if(b & 1)
            result = (long long)result * a % mod;
        a = (long long)a * a % mod;
        b >>= 1;
    }
    return result;
}
int k;
long long n;
// Factorials, inverse factorials and modular inverses for indices 0..k.
int fac[K + 5],ifac[K + 5],inv[K + 5];
// Linear-sieve scratch plus pw[i] = i^(n-1) mod `mod`.
int vis[K + 5],cnt,prime[K + 5],pw[K + 5];
// Coefficient arrays of the two recurrences and the accumulated answer.
int f[K + 5],g[K + 5],ans;
// Reads k and n (k is decremented once after reading), evaluates the two
// coefficient recurrences from index k downwards and prints the answer.
// The `register` specifiers were dropped: the keyword is no longer a valid
// storage class in C++17.
int main()
{
    scanf("%d%lld",&k,&n),--k;
    fac[0] = 1;
    for(int i = 1;i <= k;++i)
        fac[i] = (long long)fac[i - 1] * i % mod;
    ifac[k] = fpow(fac[k],mod - 2);
    for(int i = k;i;--i)
        ifac[i - 1] = (long long)ifac[i] * i % mod;
    for(int i = 1;i <= k;++i)
        inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
    // pw[i] = i^(n-1): primes by fast power, composites by multiplicativity.
    pw[0] = n == 1 ? 1 : 0,pw[1] = 1;
    for(int i = 2;i <= k;++i)
    {
        if(!vis[i])
            pw[prime[++cnt] = i] = fpow(i,(n - 1) % (mod - 1));
        for(int j = 1;j <= cnt && i * prime[j] <= k;++j)
        {
            vis[i * prime[j]] = 1,pw[i * prime[j]] = (long long)pw[i] * pw[prime[j]] % mod;
            if(!(i % prime[j]))
                break;
        }
    }
    n %= mod;
    // First recurrence (step 2, from index k down): f[k] = 1 / (2^k * k!) and
    // f[i] = -f[i+2] * (i+2+k)^2 / (k-i)^2.
    f[k] = (long long)ifac[k] * fpow(2,mod - 1 - k) % mod;
    for(int i = k - 2;i >= 0;i -= 2)
        f[i] = f[i + 2] * (mod - (long long)(i + 2 + k) * (i + 2 + k) % mod) % mod * inv[k - i] % mod * inv[k - i] % mod;
    // Contribution of these coefficients, weighted by pw[i] * n; index 0 is halved.
    for(int i = k;i > 0;i -= 2)
        ans = (ans + (long long)f[i] * pw[i] % mod * n) % mod;
    ans = (ans + (long long)f[0] * pw[0] % mod * n % mod * inv2) % mod;
    // Right-hand side of the second recurrence, built in g.
    int t = fpow(2,mod - k);
    for(int i = 0;2 * i <= k - 1;++i)
        g[k - 1 - 2 * i] = (long long)(i & 1 ? mod - t : t) * ifac[i] % mod * ifac[k - 1 - i] % mod;
    for(int i = k + 1;i >= 2;i -= 2)
        g[i] = (g[i] + g[i - 2]) % mod;
    for(int i = 2;i <= k + 2;++i)
        g[i] = (g[i] + (long long)(mod - f[i - 2]) * (mod + 2 * (i - k - 2)) + (long long)(mod - f[i]) * (2 * (i + k))) % mod;
    // Second recurrence: boundary value at index k, then every index downwards.
    f[k] = 0;
    for(int i = 1;i <= k;++i)
        f[k] = (f[k] + (long long)(i & 1 ? mod - inv[i] : inv[i]) * ifac[i] % mod * ifac[k - i]) % mod;
    f[k] = (long long)f[k] * fpow(2,mod - 1 - k) % mod;
    for(int i = k - 1;i >= 0;--i)
        f[i] = (f[i + 2] * (mod - (long long)(i + 2 + k) * (i + 2 + k) % mod) + g[i + 2]) % mod * inv[k - i] % mod * inv[k - i] % mod;
    for(int i = 1;i <= k;++i)
        ans = (ans + (long long)f[i] * pw[i] % mod * i) % mod;
    ans = 2 * ans % mod;
    printf("%d\n",ans);
}
注意到必然至少存在一个最长交替子序列经过最大值 \(n\),于是考虑枚举 \(n\) 的位置为 \(i\)。
然后首先要从 \(n-1\) 个数中选择 \(i-1\) 个放到左边。
接下来考虑两边分配的交替子序列长度。枚举 \(2r+1+s=k\),那么左边可以有 \(2r\) 或 \(2r+1\) 的交替子序列(因为 \(n\) 可以替换掉 \(2r+1\) 处的),右边可以有 \(s\) 的交替子序列。
因此 \[f_k(n) = \sum\limits_{i=1}^n \binom{n-1}{i-1} \sum\limits_{2r+s=k-1} (f_{2r}(i-1) + f_{2r+1}(i-1))f_s(n-i)\]
设 BGF \(F(x,t) = \sum f_k(n) \frac{x^n}{n!} t^k\),\(F_0(x,t) = \sum f_{2k}(n) \frac{x^n}{n!} t^{2k}\),\(F_1(x,t) = \sum f_{2k+1}(n) \frac{x^n}{n!} t^{2k+1}\)。
那么根据如上递推式,我们有 \[\begin{aligned}\frac{\partial F}{\partial x} &= (tF_0 + F_1) F \\\frac{\partial F_0}{\partial x} &= (tF_0 + F_1) F_1 \\\frac{\partial F_1}{\partial x} &= (tF_0 + F_1) F_0\end{aligned}\]
然后注意到 \[\begin{aligned}\frac{\partial(F_0^2-F_1^2)}{\partial x} &= \frac{\partial F_0^2}{\partial x} - \frac{\partial F_1^2}{\partial x} \\&= 0\end{aligned}\]
这样对于 \(F_0^2 - F_1^2\),我们只需要对照常数。而根据定义仅有 \(F_0\) 在 \(x^0\) 处有 \(1\),因此 \[\begin{aligned}&\quad \; F_0^2 - F_1^2 \\&= 1 \\&= \left(\frac12(F(x,t)+F(x,-t))\right)^2 - \left(\frac12(F(x,t)-F(x,-t))\right)^2 \\&= F(x,t) F(x,-t)\end{aligned}\]
则 \[\begin{aligned}\frac{\partial F}{\partial x} &= (tF_0 + F_1) F \\&= \frac12 ((t+1)F+(t-1)F^{-1}) F \\\frac{\partial F}{((t+1)F+(t-1)F^{-1})F} &= \frac{\partial x}2\end{aligned}\]
设 \(G = \frac F{1-t}\),则 \(F = (1-t)G\)。代入得 \[\frac{\partial G}{(1-t^2)G^2-1} = \frac{\partial x}2\]
令 \(\alpha = \sqrt{1-t^2}\),我们作部分分式分解,就有 \[\begin{aligned}\partial (\alpha G) \left(\frac1{\alpha G-1} - \frac1{\alpha G+1}\right) &= \alpha \partial x \\\ln \frac{\alpha G - 1}{\alpha G + 1} &= \alpha x + C' \\\frac{\alpha G - 1}{\alpha G + 1} &= {\rm e}^{\alpha x+C'} \\&= {\rm e}^{\alpha x} C \\G &= \frac{ 1+C{\rm e}^{\alpha x} }{\alpha(1-C{\rm e}^{\alpha x})}\end{aligned}\]
其中 \(C',C={\rm e}^{C'}\) 关于 \(x\) 都是常数。
然后注意到 \(G(0,t) = \frac1{1-t}\),因此 \[C = \frac{\frac{\alpha}{1-t}-1}{\frac{\alpha}{1-t}+1} = \frac{1-\alpha}t\]
因此 \[G = \frac1{\alpha} \left(\frac2{ 1-C{\rm e}^{\alpha x} }-1\right) = \frac1{\alpha} \left(\frac2{ 1-\frac{1-\alpha}t{\rm e}^{\alpha x} }-1\right)\]
我们显然可以不管 \([x^0]\),接下来尝试提取其系数 \[\begin{aligned}&\quad\; \frac 2{\alpha} \left(1-\frac{1-\alpha}t {\rm e}^{\alpha x}\right)^{-1} \\&= \frac 2{\alpha} \sum\limits_r \left(\frac{1-\alpha}t {\rm e}^{\alpha x} \right)^r \\&= \frac 2{\alpha} \sum\limits_r \left(\frac{1-\alpha}t\right)^r \sum\limits_{s} \frac{(r\alpha x)^s}{s!} \\&= 2 \sum\limits_r t^{-r} \sum\limits_{s} \frac{(rx)^s}{s!} (1-\alpha)^r \alpha^{s-1} \\\end{aligned}\]
在出题人给出的标准做法中,这里凑出了一个有趣的形式,然后利用具体数学中提到的恒等式来辅助推导,但未免太过匪夷所思,这里展示一种更暴力却也更自然的做法(虽然最后推得的结果形式也不尽相同)。
我们想要展开 \[(1-\sqrt{1-t^2})^r (1-t^2)^{(s-1)/2}\]
不妨作换元,令 \(4x=t^2\),则 \(\frac{t^2}4=x\),且原式变为 \[(1-\sqrt{1-4x})^r (1-4x)^{(s-1)/2}\]
令 \(F\) 满足二叉树方程 \(F = x(1+F)^2\) 就有熟知结论 \[F = \frac{ 1-2x-\sqrt{1-4x} }{2x},\quad x = \frac{F}{(1+F)^2}\]
我们把原式整理成复合形式 \[\begin{aligned}(1-\sqrt{1-4x})^r (1-4x)^{(s-1)/2}&= (2x(1+F))^r (1-2x-2xF)^{s-1} \\&= 2^r x^r (1+F)^r (1-2x(1+F))^{s-1} \\&= 2^rF^r (1-F)^{s-1} (1+F)^{-r-s+1}\end{aligned}\]
然后使用另类拉格朗日反演提取系数 \[\begin{aligned}[][x^n] F^r (1-F)^{s-1} (1+F)^{-r-s+1}&= [x^n] x^r (1-x)^{s-1} (1+x)^{-r-s+1} \frac{1-x}{(1+x)^3} (1+x)^{2n+2} \\&= [x^n] x^r (1-x)^s (1+x)^{2n-r-s} \\&= \sum\limits_i (-1)^i \binom si \binom{2n-r-s}{n-r-i}\end{aligned}\]
代回到原式,就有 \[\begin{aligned}&= 2 \sum\limits_r t^{-r} \sum\limits_{s} \frac{(rx)^s}{s!} 2^r \sum\limits_l \left(\frac{t^2}4\right)^l \sum\limits_i (-1)^i \binom{s}{i} \binom{2l-r-s}{l-r-i} \\&= 2 \sum\limits_r t^{-r} \sum\limits_{s} \frac{(rx)^s}{s!} 2^r \sum\limits_{u} \left(\frac{t^2}4\right)^{r+u} \sum\limits_i (-1)^i \binom{s}{i} \binom{r+2u-s}{u-i} \\&= \sum\limits_{r,s,u} 2^{1-r-2u} t^{r+2u} \frac{(rx)^s}{s!} \sum\limits_i (-1)^i \binom si \binom{r+2u-s}{u-i}\end{aligned}\]
则令 \(n=s,k=r+2u\),便知 \[g_k(n) = 2^{1-k} \sum\limits_{2i+r \le k,r\equiv k \pmod 2} r^n (-1)^i \binom ni \binom{k-n}{\frac{k-r}2-i}\]
枚举 \(r\),则需要计算形如 \[f_t = \sum\limits_i (-1)^i \binom ni \binom{-s}{t-i}\]
的数列。
考虑其生成函数,显然其为 \[F(x) = (1+x)^{-s} (1-x)^n\]
求导 \[F' = \frac{-sF}{1+x} - \frac{nF}{1-x}\]
即 \[tf_t = -(n+s)f_{t-1} + (s-n+t-2)f_{t-2}\]
时间复杂度 \(O(k \log_k n)\)。
感觉这题有意思的地方主要在于最前面的 DP 和微分方程部分,以及后面的提取系数部分。
由于这个题很明显是先想题意再想做法的,所以想复刻这种做法的题目还是比较无从下手(好像暴露了啥
代码:
using namespace std;
const int K = 1e6;
const int mod = 998244353;
// Modular exponentiation: returns a^b mod `mod` for b >= 0.
inline int fpow(int a,int b)
{
    int acc = 1;
    for(;b > 0;b >>= 1)
    {
        if(b & 1)
            acc = (long long)acc * a % mod;
        a = (long long)a * a % mod;
    }
    return acc;
}
// The two inputs read by main().
long long n;
int k;
int ans;
// Modular inverses of 1..k/2, filled in main().
int inv[K + 5];
// Sieve scratch (marks, prime count, prime list) and pw[i] = i^n mod `mod`.
int vis[K + 5],cnt,prime[K + 5],pw[K + 5];
// Scratch sequence reused by every call to get().
int f[K + 5];
// Returns 2^{1-k} * sum over r = k, k-2, ... >= 0 of pw[r] * f[(k-r)/2],
// where f follows the recurrence
//   t * f[t] = -(n + s) * f[t-1] + (s - n + t - 2) * f[t-2],  f[0] = 1,
// with s = (n - k) mod `mod`. Returns 0 for k <= 0.
// Relies on the global tables inv[] (up to k/2) and pw[] (up to k).
// The `register` specifier was dropped: it is not a valid storage class in C++17.
int get(long long n,int k)
{
    if(k <= 0)
        return 0;
    int ret = 0;
    int s = (n - k) % mod;
    f[0] = 1;
    for(int t = 1;t <= (k >> 1);++t)
    {
        f[t] = ((2 * mod - n % mod - s) % mod * f[t - 1] + (s + t - n % mod - 2 + mod) % mod * (t >= 2 ? f[t - 2] : 0)) % mod;
        f[t] = (long long)inv[t] * f[t] % mod;
    }
    for(int r = k;r >= 0;r -= 2)
        ret = (ret + (long long)pw[r] * f[k - r >> 1]) % mod;
    // (mod - k) is congruent to 1 - k modulo mod - 1, so this is 2^{1-k}.
    ret = (long long)ret * fpow(2,(mod - k) % (mod - 1)) % mod;
    return ret;
}
int main()
{
scanf("%lld%d",&n,&k);
pw[1] = 1;
for(register int i = 2;i <= k;++i)
{
if(!vis[i])
pw[i] = fpow(i,n % (mod - 1));
for(register int j = 1;j <= cnt && i * prime[j] <= k;++j)
{
vis[i * prime[j]] = 1,pw[i * prime[j]] = (long long)pw[i] * pw[prime[j]] % mod;
if(!(i % prime[j]))
break;
}
}
inv[1] = 1;
for(register int i = 2;i <= (k >> 1);++i)
inv[i] = (long long)(mod - mod / i) * inv[mod % i] % mod;
ans = (get(n,k) - get(n,k - 1) + mod) % mod;
printf("%d\n",ans);
}
一个连通块有三种情况:环、链、点,接下来我们分别来考虑它们。
对于环,它显然是偶环,那么左右部点的个数是一样的,并且正反各会被统计一次。而一边是排列,另一边是环排列,因此就是 \[\frac 12\left(-\ln(1-xy)-xy\right)\]
对于链,右部点一定比左部点多一个,并且正反各会被统计一次。且两边都是排列,因此 \[\frac 12 \frac{xy^2}{1-xy}\]
对于点,必然是一个右部点,因此 \[y\]
现在将它们组合,那么答案就是 \[\begin{aligned}n!m! [x^n y^m] F(x,y) &= n!m! [x^n y^m] \exp\left(\frac 12\left(-\ln(1-xy)-xy\right)+\frac 12 \frac{xy^2}{1-xy}+y\right) \\&= n!m! [s^n t^{m-n}] \exp\left(\frac 12\left(-\ln(1-s)-s\right)+\frac 12 \frac{st}{1-s}+t\right) \\&= n!m! [s^n t^{m-n}] \exp\left(\frac 12\left(-\ln(1-s)-s\right)\right) {\rm e}^t \sum\limits_{k\ge 0} \frac 1{k!} \left(\frac 12 \frac{st}{1-s}\right)^k \\&= n!m^{\underline n} [s^n] \exp\left(\frac 12\left(-\ln(1-s)-s\right)\right) \sum\limits_{k\ge 0} \binom{m-n}k \left(\frac 12 \frac s{1-s}\right)^k \\&= n!m^{\underline n} [s^n] \exp\left(\frac 12\left(-\ln(1-s)-s\right)\right) \left(1+\frac 12 \frac s{1-s}\right)^{m-n} \\&= n!m^{\underline n} [s^n] \frac1{ \sqrt{1-s} } {\rm e}^{-s/2} \left(\frac{2-s}{2-2s}\right)^{m-n}\end{aligned}\]
而其三者均微分有限,故其乘积也微分有限,借助计算机便可得到递推式(记 \(k = m-n\),\(f_i\) 为上式中 \(s^i\) 项的系数;代码中维护的是 \(i!\,f_i\)) \[f_i = \frac1{2i}\left((k+3i-3) f_{i-1} + (3-i) f_{i-2} - \frac 12 f_{i-3}\right)\]
代码:
using namespace std;
// Upper bound on n; the table f is declared with N + 5 entries.
const int N = 2e6;
// Modulus under which all arithmetic in this program is performed.
const int mod = 943718401;
// Binary (square-and-multiply) exponentiation: returns a^b reduced modulo `mod`.
// Each product is widened to long long before the reduction to avoid int overflow.
inline int fpow(int a,int b)
{
    int result = 1;
    while (b)
    {
        // Low bit of the exponent set: multiply the current power into the result.
        if (b & 1)
            result = (long long)result * a % mod;
        // Advance to the next bit: square the base, shift the exponent.
        a = (long long)a * a % mod;
        b >>= 1;
    }
    return result;
}
// n: first input value; k: (m - n) reduced modulo mod, set in main().
int n,k;
// m: second input value (read as a 64-bit integer).
long long m;
// f[i]: i-th term of the recurrence evaluated in main(); ans: value printed at the end.
int f[N + 5],ans;
// Reads n and m, evaluates the recurrence for f[0..n] with k = (m - n) mod `mod`,
// and prints  m * (m-1) * ... * (m-n+1) * f[n]  modulo `mod`.
int main()
{
    scanf("%d%lld",&n,&m);
    k = (m - n) % mod;

    // 471859201 * 2 == mod + 1, so this is 2^{-1}; 471859200 == mod - 2^{-1} is -2^{-1}.
    const long long inv_two = 471859201LL;
    const long long neg_inv_two = 471859200LL;

    // Closed-form seeds for the first four terms (polynomials in k with
    // modular-fraction coefficients, scaled by 1!, 2! and 3!).
    f[0] = 1;
    f[1] = inv_two * k % mod;
    f[2] = (825753601LL * k % mod * k + 589824001LL * k + 707788801) % mod * 2 % mod;
    f[3] = (924057601LL * k % mod * k % mod * k + 766771201LL * k % mod * k + 550502401LL * k + 786432001) % mod * 6 % mod;

    // f[i] = 1/2 * ( (k + 3(i-1)) f[i-1] + (3-i)(i-1) f[i-2] - 1/2 (i-1)(i-2) f[i-3] ).
    for (int i = 4; i <= n; ++i)
    {
        long long acc = (k + 3LL * (i - 1)) * f[i - 1] % mod;
        acc = (acc + (3LL - i + mod) * (i - 1) % mod * f[i - 2]) % mod;
        acc = (acc + neg_inv_two * (i - 1) % mod * (i - 2) % mod * f[i - 3]) % mod;
        f[i] = inv_two * acc % mod;
    }

    // Falling factorial m * (m-1) * ... * (m-n+1), reduced after every factor.
    long long falling = 1;
    for (int i = 1; i <= n; ++i)
        falling = (m - i + 1) % mod * falling % mod;

    ans = falling * f[n] % mod;
    printf("%d\n",ans);
}