洛谷 3301 「SDOI2013」方程

水题(

\(\ge\) 的限制很好处理:对限制 \(x_i \ge A_i\),令 \(x_i' = x_i - (A_i - 1)\),也就是从 \(m\) 中减去 \(A_i - 1\),之后 \(x_i'\) 仍然只需要是正整数即可。
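\(\le\) 的限制,拆成两个 \(\ge\) 相减的形式容斥即可:满足 \(x_i \le A_i\) 的方案数等于无限制(\(x_i \ge 1\))的方案数减去满足 \(x_i \ge A_i + 1\) 的方案数。枚举前 \(n_1\) 个变量中哪些违反上界(共 \(2^{n_1}\) 种情况),每选中一个变量就从 \(m\) 中减去 \(A_i\) 并把系数乘上 \(-1\),最后用隔板法得到 \(\binom{m'-1}{n-1}\);模数不一定是质数,所以组合数用扩展 Lucas 加 CRT 计算。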

代码:

#include <cmath>
#include <cstdio>
using namespace std;
// Sizing constant for the bound array a[] (declared with N + 5 slots).
const int N = 16;
// Sizing constant for the factorial table f[] (declared with MX + 5 slots).
const int MX = 1e6;
// Sizing constant for the per-prime-power arrays inside C(n, m, mod).
const int LG = 30;
// T: number of test cases; mod: modulus. Both are read once at the start of main.
int T,mod;
// Per-test values read in main. m is adjusted in place by main (lower bounds)
// and temporarily by dfs (inclusion-exclusion over upper bounds).
int n,n1,n2,m;
// Bounds read in main into indices 1..n1+n2; dfs uses a[1..n1], main uses a[n1+1..n1+n2].
int a[N + 5];
// Modular exponentiation by repeated squaring: returns a^b reduced modulo mod.
// Every product is widened to long long before the reduction.
// b is expected to be non-negative.
int fpow(int a,int b,int mod)
{
    int result = 1;
    while(b)
    {
        // Multiply the current power in when the lowest exponent bit is set.
        if(b & 1)
            result = (long long)result * a % mod;
        a = (long long)a * a % mod;
        b >>= 1;
    }
    return result;
}
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores in x, y coefficients with a*x + b*y == g.
// The mod argument is only forwarded through the recursion and never read.
int exgcd(int a,int b,int &x,int &y,int mod)
{
    if(b == 0)
    {
        // gcd(a, 0) = a = a*1 + 0*0
        x = 1;
        y = 0;
        return a;
    }
    int subX = 0,subY = 0;
    const int g = exgcd(b,a % b,subX,subY,mod);
    // From b*subX + (a % b)*subY == g and a % b == a - (a / b)*b.
    x = subY;
    y = subX - a / b * subY;
    return g;
}
// Modular inverse via the extended Euclidean algorithm.
// Returns the Bezout coefficient of a (from a*x + mod*y == gcd) normalised into [0, mod).
// That value is the inverse of a only when gcd(a, mod) == 1.
int inv(int a,int mod)
{
    int coefA = 0,coefMod = 0;
    exgcd(a,mod,coefA,coefMod,mod);
    int reduced = coefA % mod;
    if(reduced < 0)
        reduced += mod;
    return reduced;
}
// f[i] = product, modulo the current prime power, of all j in [1, i] that are not
// divisible by the current prime p (f[0] = 1). Refilled by every call to C(n, m, p, mod).
int f[MX + 5];
// Returns n! with every factor of p stripped out, reduced modulo mod.
// Relies on the table f[] having been filled for this (p, mod) pair.
// Per level: the numbers 1..n split into n / mod full periods (each contributing
// f[mod - 1]) plus a tail of n % mod numbers (contributing f[n % mod]); the multiples
// of p, divided by p, form the same problem for n / p.
int fac(int n,int p,int mod)
{
    long long product = 1;
    for(;n;n /= p)
    {
        product = product * fpow(f[mod - 1],n / mod,mod) % mod;
        product = product * f[n % mod] % mod;
    }
    return (int)product;
}
// Binomial coefficient C(n, m) modulo a prime power (one extended-Lucas step).
//   n, m : arguments of the binomial; returns 0 when m < 0 or n < m.
//   p    : the prime.
//   mod  : the modulus, a power of p as passed by C(n, m, mod); must not exceed the size of f[].
// The table f[] is rebuilt on every call, so each call costs O(mod).
int C(int n,int m,int p,int mod)
{
    // A negative m would otherwise reach f[n % mod] with a negative index inside fac().
    if(m < 0 || n < m)
        return 0;

    // f[i]: product of all j <= i that are coprime to p, modulo mod.
    f[0] = 1;
    for(int i = 1;i < mod;++i)
        f[i] = (i % p) ? (int)((long long)f[i - 1] * i % mod) : f[i - 1];

    // Exponent of p in C(n, m) by Legendre's formula: v_p(n!) - v_p(m!) - v_p((n - m)!).
    int cnt = 0;
    for(int i = n;i;i /= p)
        cnt += i / p;
    for(int i = m;i;i /= p)
        cnt -= i / p;
    for(int i = n - m;i;i /= p)
        cnt -= i / p;

    // The p-free parts of the factorials are coprime to mod, so they can be inverted.
    const int unitPart = (long long)fac(n,p,mod) * inv(fac(m,p,mod),mod) % mod * inv(fac(n - m,p,mod),mod) % mod;
    return (long long)unitPart * fpow(p,cnt,mod) % mod;
}
int C(int n,int m,int mod)
{
int bound = sqrt(mod);
int a[LG + 5],c[LG + 5],tot = 0;
int M = 1,ans = 0;
for(register int i = 2;i <= bound && mod > 1;++i)
{
int prod = 1;
for(;!(mod % i);mod /= i,prod *= i);
prod > 1 && (a[++tot] = C(n,m,i,prod),c[tot] = prod);
}
mod > 1 && (a[++tot] = C(n,m,mod,mod),c[tot] = mod);
for(register int i = 1;i <= tot;++i)
M *= c[i];
for(register int i = 1;i <= tot;++i)
ans = (ans + (long long)a[i] * (M / c[i]) % M * inv(M / c[i],c[i])) % M;
return ans;
}
// coe: coefficient of the current inclusion-exclusion term, multiplied by mod - 1
//      (i.e. -1 modulo mod) each time dfs selects a bound; starts at 1.
// ans: answer accumulated by dfs for the current test case; reset in main.
int coe = 1,ans;
// Inclusion-exclusion over the first n1 bounds a[1..n1].
// For bound k the search either leaves m untouched, or subtracts a[k] from m and flips
// the sign of coe. Each leaf adds coe * C(m - 1, n - 1) modulo mod to ans.
// m and coe are restored before returning. Assumes the bounds a[k] are non-negative.
void dfs(int k)
{
    // Once m < n every leaf below contributes C(m' - 1, n - 1) with m' <= m < n, i.e. 0.
    // Stopping here also keeps m from running far below zero and overflowing int
    // when several large bounds are subtracted in a row.
    if(m < n)
        return;
    if(k > n1)
    {
        ans = (ans + (long long)coe * C(m - 1,n - 1,mod)) % mod;
        return;
    }
    // Branch 1: bound k is not selected.
    dfs(k + 1);
    // Branch 2: bound k is selected, so a[k] is removed from m and the sign flips.
    m -= a[k];
    coe = (long long)coe * (mod - 1) % mod;
    dfs(k + 1);
    // Undo both changes; multiplying by mod - 1 a second time restores the sign.
    m += a[k];
    coe = (long long)coe * (mod - 1) % mod;
}
int main()
{
scanf("%d%d",&T,&mod);
for(;T;--T)
{
scanf("%d%d%d%d",&n,&n1,&n2,&m),ans = 0;
for(register int i = 1;i <= n1 + n2;++i)
scanf("%d",a + i),i > n1 && (m -= a[i] - 1);
dfs(1);
printf("%d\n",ans);
}
}