JZOJ 6977.去南极

又是 dirty works……(

首先有 DP \[ f_n = s_n + \frac 1{\binom n3} \sum\limits_{i=1}^n (i-1)(n-i)(f_{i-1} + f_{n-i}) \]

整理一下可得 \[ \binom n3 f_n = \binom n3 s_n + 2 \sum\limits_{i=1}^{n-2} (n-i-1)i f_i \]

记 \(F(x) = \sum f_n x^n\),\(S(x) = \sum s_n x^n\),写成生成函数的形式就是(第二行由第一行两边同乘 \((1-x)^3\) 得到) \[ \begin{aligned} F^{(3)}(x) &= S^{(3)}(x) + \frac{12}{(1-x)^2}F'(x) \\ (1-x)^3 F^{(3)} - 12(1-x)F' &= (1-x)^3 S^{(3)} \end{aligned} \]

我也不知道为什么要这样做,总之定义算子 \[ \theta F(x) = (1-x)F'(x) \]

(注意这里的定义和题解不一样)

然后可以发现 \[ \begin{aligned} \theta^2 F &= (1-x)^2F'' - (1-x)F' \\ \theta^3 F &= (1-x)^3F^{(3)} - 3(1-x)^2F'' + (1-x)F' \end{aligned} \]

由上面两式可得 \((1-x)^3F^{(3)} = \theta^3 F + 3\theta^2 F + 2\theta F\),因此 \[ \begin{aligned} (\theta^3+3\theta^2-10\theta)F &= (1-x)^3 S^{(3)} \\ \theta(\theta-2)(\theta+5)F &= (1-x)^3 S^{(3)} \end{aligned} \]

\(Q = (\theta + 5)F,P = (\theta - 2)Q\),则有 \[ \begin{cases} \theta P = (1-x)^3 S^{(3)} \\ (\theta - 2) Q = P \\ (\theta + 5) F = Q \end{cases} \]

提取系数得(其中 \(\triangledown\) 为后向差分,作用于整个乘积 \(s_{n+3}(n+3)^{\underline 3}\)) \[ \begin{cases} (n+1)p_{n+1} = np_n + \triangledown^3 s_{n+3} (n+3)^{\underline 3} \\ (n+1)q_{n+1} = (n+2) q_n + p_n \\ (n+1)f_{n+1} = (n-5)f_n + q_n \end{cases} \]

从第一条式子有 \[ np_n = \triangledown^2 s_{n+2} (n+2)^{\underline 3} \]

再代入第二条式子,展开得 \[ \frac{ q_{n+1} }{n+2} = \frac{q_{m+1}}{m+2} + \sum\limits_{i=m+3}^{n+1} \triangledown^2 s_i \]

再代入第三条式子,展开得 \[ \binom n6 f_n = \binom{m+1}6 f_{m+1} + \frac{q_{m+1}}{m+2} \sum\limits_{i=m+2}^n \binom i6 + \sum\limits_{i=m+2}^n \binom i6 \sum\limits_{j=m+3}^i \triangledown^2 s_j \]

显然有 \[ f_{m+1} = s_{m+1},f_{m+2} = s_{m+2} \]

然后由第三条式子取 \(n = m+1\) 得 \[ q_{m+1} = (m+2)f_{m+2}-(m-4)f_{m+1} \]

再经过一通计算可以得到 \[ f_n = s_n + \frac{12}{\binom n6} \sum\limits_{i=m+1}^{n-1} \frac{s_i}{(i+1)(i+2)}\left(\binom{n+1}7-\binom{i+2}7\right) \]

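由于 \(\binom{n+1}7\) 只与 \(n\) 有关,只需在区间加的同时维护 \(\sum \frac{s_i}{(i+1)(i+2)}\) 与 \(\sum \frac{s_i}{(i+1)(i+2)}\binom{i+2}7\) 两个区间和,使用线段树维护即可;单点的 \(s_n\) 用树状数组(差分)维护。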

代码:

#include <cstdio>
#include <utility>
#include <algorithm>
using namespace std;
// Capacity bound for every per-position array below (positions are 1-indexed).
constexpr int N = 200001;
// Prime modulus for all arithmetic; main's linear inverse recurrence relies on it being prime.
constexpr int mod = 998244353;
// L: number of positions; q: number of operations. Both are read in main.
int L;
int q;
// a[i]: initial value at position i as read from the input.
int a[N + 5];
// inv[i]: modular inverse of i, filled in main.
int inv[N + 5];
// s1[i]: prefix sum over j <= i of 1 / ((j+1)(j+2)) (mod `mod`), filled in main.
int s1[N + 5];
// s2[i]: prefix sum over j <= i of j(j-1)(j-2)(j-3)(j-4) / 7! (mod `mod`), filled in main.
int s2[N + 5];
namespace SEG
{
// Lazy segment tree over positions 1..L (callers use root p = 1 with range [1, L]).
// Each position i carries a value v_i that is changed by range additions. A node
// covering [tl, tr] stores, modulo `mod`,
//   sum1 = sum_{i=tl..tr} v_i * (s1[i] - s1[i-1])
//   sum2 = sum_{i=tl..tr} v_i * (s2[i] - s2[i-1])
// The per-position weights come from the global prefix arrays s1 / s2, so a
// range addition of k changes sum1 by k * (s1[tr] - s1[tl-1]) (same for sum2).
// tag is a pending addition to every v_i below the node, not yet pushed to children.
#define ls (p << 1)
#define rs (ls | 1)
struct node
{
    int sum1,sum2;
    int tag;
} seg[(N << 2) + 5];

// Add k to every v_i in node p's range [tl, tr]: update both weighted sums and
// accumulate k into the node's lazy tag.
inline void add_tag(int p,int tl,int tr,int k)
{
    seg[p].sum1 = (seg[p].sum1 + (long long)k * (s1[tr] - s1[tl - 1] + mod)) % mod;
    seg[p].sum2 = (seg[p].sum2 + (long long)k * (s2[tr] - s2[tl - 1] + mod)) % mod;
    seg[p].tag = (seg[p].tag + k) % mod;
}

// Push node p's pending tag (p covers [tl, tr]) down into its two children.
inline void push(int p,int tl,int tr)
{
    if(!seg[p].tag)
        return ;
    int mid = (tl + tr) >> 1;
    add_tag(ls,tl,mid,seg[p].tag);
    add_tag(rs,mid + 1,tr,seg[p].tag);
    seg[p].tag = 0;
}

// Add k to every v_i with l <= i <= r. p is the current node covering [tl, tr].
// An empty range (l > r) is a no-op.
void update(int l,int r,int k,int p,int tl,int tr)
{
    if(l > r)
        return ;
    if(l <= tl && tr <= r)
    {
        add_tag(p,tl,tr,k);
        return ;
    }
    push(p,tl,tr);
    int mid = (tl + tr) >> 1;
    if(l <= mid)
        update(l,r,k,ls,tl,mid);
    if(r > mid)
        update(l,r,k,rs,mid + 1,tr);
    seg[p].sum1 = (seg[ls].sum1 + seg[rs].sum1) % mod;
    seg[p].sum2 = (seg[ls].sum2 + seg[rs].sum2) % mod;
}

// Component-wise sum of two (sum1, sum2) pairs modulo `mod`.
inline pair<int,int> operator+(const pair<int,int> &a,const pair<int,int> &b)
{
    return make_pair((a.first + b.first) % mod,(a.second + b.second) % mod);
}

// Return (sum1, sum2) restricted to positions l..r. p is the current node covering [tl, tr].
pair<int,int> query(int l,int r,int p,int tl,int tr)
{
    // Empty range (e.g. main asks for [m+1, n-1] with n == m+1): the sums are 0.
    // Without this guard no node is ever fully covered, so the recursion descends
    // past the leaves and indexes seg[] out of bounds.
    if(l > r)
        return make_pair(0,0);
    if(l <= tl && tr <= r)
        return make_pair(seg[p].sum1,seg[p].sum2);
    push(p,tl,tr);
    int mid = (tl + tr) >> 1;
    pair<int,int> ret(0,0);
    if(l <= mid)
        ret = ret + query(l,r,ls,tl,mid);
    if(r > mid)
        ret = ret + query(l,r,rs,mid + 1,tr);
    return ret;
}
}
namespace BIT
{
// Fenwick tree over positions 1..L holding a difference array. A range addition
// becomes two point updates, and query(x) (a prefix sum of the differences)
// returns the current value at position x. All values are kept modulo `mod`.
#define lowbit(x) ((x) & -(x))
int c[N + 5];

// Point update of the difference array: add k at x and at every Fenwick
// ancestor of x up to L.
inline void update(int x,int k)
{
    while(x <= L)
    {
        c[x] = (c[x] + k) % mod;
        x += lowbit(x);
    }
}

// Add k to every position in [l, r]: +k at l, -k (as mod - k) just past r.
inline void update(int l,int r,int k)
{
    const int minus_k = (mod - k) % mod;
    update(l,k);
    update(r + 1,minus_k);
}

// Prefix sum of the difference array up to x, i.e. the current value at x.
inline int query(int x)
{
    int acc = 0;
    while(x != 0)
    {
        acc = (acc + c[x]) % mod;
        x -= lowbit(x);
    }
    return acc;
}
}
int ans;
int main()
{
freopen("space.in","r",stdin),freopen("space.out","w",stdout);
scanf("%d%d",&L,&q),inv[1] = 1;
for(register int i = 2;i <= N;++i)
inv[i] = (long long)(mod - mod / i) * inv[mod % i] % mod;
for(register int i = 1;i <= L;++i)
scanf("%d",a + i),
s1[i] = (s1[i - 1] + (long long)inv[i + 1] * inv[i + 2]) % mod,
s2[i] = (s2[i - 1] + (long long)i * (i - 1) % mod * (i - 2) % mod * (i - 3) % mod * (i - 4) % mod * inv[5040]) % mod;
for(register int i = 1;i <= L;++i)
SEG::update(i,i,a[i],1,1,L),
BIT::update(i,i,a[i]);
for(int op,l,r,k,n,m;q;--q)
{
scanf("%d",&op);
if(!op)
scanf("%d%d%d",&l,&r,&k),
SEG::update(l,r,k,1,1,L),
BIT::update(l,r,k);
else
{
scanf("%d%d",&n,&m);
if(n <= max(m,5))
printf("%d\n",BIT::query(n));
else
{
pair<int,int> temp = SEG::query(m + 1,n - 1,1,1,L);
ans = (long long)temp.first * (n + 1) % mod * n % mod * (n - 1) % mod * (n - 2) % mod * (n - 3) % mod * (n - 4) % mod * (n - 5) % mod * inv[5040] % mod,
ans = (ans - temp.second + mod) % mod,
ans = (12LL * 720 * ans % mod * inv[n] % mod * inv[n - 1] % mod * inv[n - 2] % mod * inv[n - 3] % mod * inv[n - 4] % mod * inv[n - 5] + BIT::query(n)) % mod;
printf("%d\n",ans);
}
}
}
}