LibreOJ 3120 「CTS2019」珍珠

设颜色为 \(i\) 的珍珠有 \(c_i\) 个(\(\sum_{i=1}^D c_i = n\)),则 \[ \begin{aligned} \sum\limits_{i=1}^D \left\lfloor\frac{c_i}2\right\rfloor &\ge m \\ \sum\limits_{i=1}^D \frac{c_i - c_i \bmod 2}2 &\ge m \\ \frac{n - \sum\limits_{i=1}^D c_i \bmod 2}2 &\ge m \\ \sum\limits_{i=1}^D c_i \bmod 2 &\le n - 2m \end{aligned} \]

若 \(n-2m \ge D\),由于出现奇数次的颜色至多 \(D\) 种,条件恒成立,答案为 \(D^n\);
若 \(n-2m < 0\),条件不可能成立,答案为 \(0\)。

首先讲一个垃圾做法。
设 \(f_i\) 表示恰有 \(i\) 种颜色出现奇数次的方案数,\(g_i\) 表示钦定 \(i\) 种颜色出现奇数次(其余颜色任意)的方案数,则 \[ g_i = \sum\limits_{j=i}^D \binom ji f_j \Longleftrightarrow f_i = \sum\limits_{j=i}^D (-1)^{j-i} \binom ji g_j \]

后者(由 \(g\) 求 \(f\) 的反演式)把组合数拆成阶乘后是一个差卷积,显然可以用 NTT 处理。下面考虑如何计算 \(g_i\)。

根据基本的 EGF 知识,有 \[ \def\e{ {\rm e} } \begin{aligned} g_i &= \binom Di n![x^n] \left(\frac{\e^x-\e^{-x}}2\right)^i (\e^x)^{D-i} \\ &= \binom Di \frac{n!}{2^i}[x^n] \sum\limits_{j=0}^i \binom ij \e^{jx} (-\e^{-x})^{i-j} \e^{(D-i)x} \\ &= \binom Di \frac{n!}{2^i} \sum\limits_{j=0}^i \binom ij (-1)^{i-j} [x^n] \e^{(D-2i+2j)x} \\ &= \binom Di \frac{n!}{2^i} \sum\limits_{j=0}^i \binom ij (-1)^{i-j} \frac{(D-2i+2j)^n}{n!} \\ &= \frac{D!}{2^i(D-i)!} \sum\limits_{j=0}^i \frac{1}{j!} \cdot (-1)^{i-j} \frac{(D-2i+2j)^n}{(i-j)!} \end{aligned} \]

两次卷积即可计算答案。时间复杂度 \(O(D \log D)\)

为什么说这个做法垃圾?
真的有必要二项式反演吗?
实际上直接枚举奇数的个数并不难计算。
再进行各种生成函数推导不难得到一个 \(O(D \log_D n)\) 做法,可以近似认为是 \(O(D)\)
这里就不写了,因为鸽(

代码:

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
// Appears unused: inside namespace Poly the name N refers to Poly::N instead.
const int N = 1e5;
// Prime modulus used for all arithmetic in this file.
const int mod = 998244353;

// (a + b) mod `mod` for operands already reduced into [0, mod).
// These replace function-like macros that neither parenthesised nor
// single-evaluated their arguments; a + b < 2 * mod still fits in int.
inline int add(int a,int b)
{
    const int s = a + b;
    return s >= mod ? s - mod : s;
}

// (a - b) mod `mod` for operands already reduced into [0, mod).
inline int dec(int a,int b)
{
    return a < b ? a - b + mod : a - b;
}
// Binary exponentiation: returns a^b modulo `mod`.
// Intended for 0 <= a < mod and b >= 0; yields 1 whenever b == 0 (so 0^0 == 1).
inline int fpow(int a,int b)
{
    long long base = a;
    long long result = 1;
    while(b)
    {
        // Multiply in the current bit before squaring, scanning b from the low end.
        if(b & 1)
            result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// Problem input, read by scanf in main(): D colours, n pearls, and the
// requirement that sum over colours of floor(c_i / 2) be at least m.
// n and m are long long because they are read with %lld.
int D;
long long n,m;
namespace Poly
{
const int N = 1 << 18;
const int G = 3;
int lg2[N + 5];
int rev[N + 5],fac[N + 5],ifac[N + 5],inv[N + 5];
int rt[N + 5],irt[N + 5];
inline void init()
{
for(register int i = 2;i <= N;++i)
lg2[i] = lg2[i >> 1] + 1;
int w = fpow(G,(mod - 1) / N);
rt[N >> 1] = 1;
for(register int i = (N >> 1) + 1;i <= N;++i)
rt[i] = (long long)rt[i - 1] * w % mod;
for(register int i = (N >> 1) - 1;i;--i)
rt[i] = rt[i << 1];
fac[0] = 1;
for(register int i = 1;i <= N;++i)
fac[i] = (long long)fac[i - 1] * i % mod;
ifac[N] = fpow(fac[N],mod - 2);
for(register int i = N;i;--i)
ifac[i - 1] = (long long)ifac[i] * i % mod;
for(register int i = 1;i <= N;++i)
inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
}
struct poly
{
vector<int> a;
inline poly(int x = 0)
{
x && (a.push_back(x),1);
}
inline poly(const vector<int> &o)
{
a = o,shrink();
}
inline void shrink()
{
for(;!a.empty() && !a.back();a.pop_back());
}
inline int size() const
{
return a.size();
}
inline void resize(int x)
{
a.resize(x);
}
inline int operator[](int x) const
{
if(x < 0 || x >= size())
return 0;
return a[x];
}
inline int &operator[](int x)
{
return a[x];
}
inline void clear()
{
vector<int>().swap(a);
}
inline void ntt(int type = 1)
{
int n = size();
type == -1 && (reverse(a.begin() + 1,a.end()),1);
int lg = lg2[n] - 1;
for(register int i = 0;i < n;++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg),
i < rev[i] && (swap(a[i],a[rev[i]]),1);
for(register int w = 2,m = 1;w <= n;w <<= 1,m <<= 1)
for(register int i = 0;i < n;i += w)
for(register int j = 0;j < m;++j)
{
int t = (long long)rt[m | j] * a[i | j | m] % mod;
a[i | j | m] = dec(a[i | j],t),a[i | j] = add(a[i | j],t);
}
if(type == -1)
for(register int i = 0;i < n;++i)
a[i] = (long long)a[i] * inv[n] % mod;
}
friend inline poly operator+(const poly &a,const poly &b)
{
vector<int> ret(max(a.size(),b.size()));
for(register int i = 0;i < ret.size();++i)
ret[i] = add(a[i],b[i]);
return poly(ret);
}
friend inline poly operator-(const poly &a,const poly &b)
{
vector<int> ret(max(a.size(),b.size()));
for(register int i = 0;i < ret.size();++i)
ret[i] = dec(a[i],b[i]);
return poly(ret);
}
friend inline poly operator*(poly a,poly b)
{
if(a.a.empty() || b.a.empty())
return poly();
int lim = 1,tot = a.size() + b.size() - 1;
for(;lim < tot;lim <<= 1);
a.resize(lim),b.resize(lim);
a.ntt(),b.ntt();
for(register int i = 0;i < lim;++i)
a[i] = (long long)a[i] * b[i] % mod;
a.ntt(-1),a.shrink();
return a;
}
poly &operator+=(const poly &o)
{
resize(max(size(),o.size()));
for(register int i = 0;i < o.size();++i)
a[i] = add(a[i],o[i]);
return *this;
}
poly &operator-=(const poly &o)
{
resize(max(size(),o.size()));
for(register int i = 0;i < o.size();++i)
a[i] = dec(a[i],o[i]);
return *this;
}
poly &operator*=(poly o)
{
return (*this) = (*this) * o;
}
poly deriv() const
{
if(a.empty())
return poly();
vector<int> ret(size() - 1);
for(register int i = 0;i < size() - 1;++i)
ret[i] = (long long)(i + 1) * a[i + 1] % mod;
return poly(ret);
}
poly integ() const
{
if(a.empty())
return poly();
vector<int> ret(size() + 1);
for(register int i = 0;i < size();++i)
ret[i + 1] = (long long)a[i] * inv[i + 1] % mod;
return poly(ret);
}
inline poly modxn(int n) const
{
n = min(n,size());
return poly(vector<int>(a.begin(),a.begin() + n));
}
inline poly inver(int m) const
{
poly ret(fpow(a[0],mod - 2));
for(register int k = 1;k < m;)
k <<= 1,ret = (ret * (2 - modxn(k) * ret)).modxn(k);
return ret.modxn(m);
}
inline poly log(int m) const
{
return (deriv() * inver(m)).integ(),modxn(m);
}
inline poly exp(int m) const
{
poly ret(1);
for(register int k = 1;k < m;)
k <<= 1,ret = (ret * (1 - ret.log(k) + modxn(k))).modxn(k);
return ret.modxn(m);
}
};
}
using Poly::init;
using Poly::poly;
poly f,g;
int ans;
int main()
{
Poly::init();
scanf("%d%lld%lld",&D,&n,&m);
if(n - 2 * m >= D)
{
printf("%d\n",fpow(D,n % (mod - 1)));
return 0;
}
else if(n < 2 * m)
{
puts("0");
return 0;
}
f.resize(D + 1),g.resize(D + 1);
for(register int i = 0;i <= D;++i)
f[i] = Poly::ifac[i],
g[i] = (long long)(i & 1 ? mod - 1 : 1) * fpow((D - 2 * i + mod) % mod,n % (mod - 1)) % mod * Poly::ifac[i] % mod;
g *= f,g.resize(D + 1);
for(register int i = 0;i <= D;++i)
g[i] = (long long)g[i] * Poly::fac[D] % mod * Poly::ifac[D - i] % mod * fpow(2,mod - 1 - i) % mod * Poly::fac[i] % mod;
reverse(g.a.begin(),g.a.end());
for(register int i = 0;i <= D;++i)
f[i] = (long long)(i & 1 ? mod - 1 : 1) * Poly::ifac[i] % mod;
f *= g,f.resize(D + 1),reverse(f.a.begin(),f.a.end());
for(register int i = 0;i <= D;++i)
f[i] = (long long)f[i] * Poly::ifac[i] % mod;
for(register int i = 0;i <= n - 2 * m;++i)
ans = add(ans,f[i]);
printf("%d\n",ans);
}