JZOJ 5231.序列问题

真是个码农题……

由于是乘,所以没法用单调栈来统计答案贡献。

考虑分治套路,于是问题变成了求区间左端点 \(\in [l,mid]\),区间右端点 \(\in [mid + 1,r]\) 的答案。
由于最值的单调性,我们可以多维护两个指针使得每层递归只需 \(O(n)\)

具体地说,我们按 \(mid \rightarrow l\) 顺序枚举 \(i\),并同时维护满足 \(\max\limits_{j = mid + 1}^{j_1}{a_j} \le \max\limits_{j = i}^{mid}{a_j}\) 的最大的 \(j_1\) 和满足 \(\min\limits_{j = mid + 1}^{j_2}{a_j} \ge \min\limits_{j = i}^{mid}{a_j}\) 的最大的 \(j_2\)
于是我们就可以把计算 \([mid + 1,r]\)\(i\) 的贡献分成三个部分,这里我们假设 \(j_1 < j_2\),反过来类似: 1. \([mid + 1,j_1]\),这一段的最值与 \([i,mid]\) 相同。 2. \([j_1 + 1,j_2]\),这一段仅最小值与 \([i,mid]\) 相同。 3. \([j_2 + 1,r]\) 这一段的最值与 \([i,mid]\) 不同。

于是我们可以在递归进入时维护五个量: 1. \(max_i = \max\limits_{j = mid + 1}^i{a_j}\) 2. \(min_i = \min\limits_{j = mid + 1}^i{a_j}\) 3. \(maxsum_i = \sum\limits_{j = mid + 1}^i{max_j}\) 4. \(minsum_i = \sum\limits_{j = mid + 1}^i{min_j}\) 5. \(sum_i = \sum\limits_{j = mid + 1}^i{max_i \cdot min_i}\)

通过这五个量,三个部分的贡献就容易得出了。

此外,注意取模时,应该把较大的数先取模再乘。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 5e5;
const long long mod = 1e9 + 7;
int n;
long long a[N + 5],sum[N + 5],mx[N + 5],mn[N + 5],mxsum[N + 5],mnsum[N + 5];
long long ans;
void solve(int l,int r)
{
if(l == r)
{
ans = (ans + a[l] * a[l] % mod) % mod;
return ;
}
int mid = l + r >> 1;
solve(l,mid),solve(mid + 1,r);
mxsum[mid] = mnsum[mid] = sum[mid] = 0;
mx[mid] = -0x3f3f3f3f3f3f3f3fLL,mn[mid] = 0x3f3f3f3f3f3f3f3fLL;
for(register int i = mid + 1;i <= r;++i)
mx[i] = max(mx[i - 1],a[i]),
mn[i] = min(mn[i - 1],a[i]),
mxsum[i] = mxsum[i - 1] + mx[i],
mnsum[i] = mnsum[i - 1] + mn[i],
sum[i] = (sum[i - 1] + mx[i] * mn[i] % mod) % mod;
long long curmax = -0x3f3f3f3f3f3f3f3fLL,curmin = 0x3f3f3f3f3f3f3f3fLL;
for(register int i = mid,j1 = mid,j2 = mid;i >= l;--i)
{
curmax = max(curmax,a[i]),curmin = min(curmin,a[i]);
while(j1 < r && mx[j1 + 1] <= curmax)
++j1;
while(j2 < r && mn[j2 + 1] >= curmin)
++j2;
if(j1 <= j2)
ans = (ans +
curmax % mod * curmin % mod * (j1 - mid) % mod +
(mxsum[j2] - mxsum[j1]) % mod * curmin % mod +
(sum[r] - sum[j2]) % mod
) % mod;
else
ans = (ans +
curmax % mod * curmin % mod * (j2 - mid) % mod +
(mnsum[j1] - mnsum[j2]) % mod * curmax % mod +
(sum[r] - sum[j1]) % mod
) % mod;
}
}
int main()
{
freopen("seq.in","r",stdin);
freopen("seq.out","w",stdout);
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%lld",a + i);
solve(1,n);
printf("%lld\n",(ans + mod) % mod);
}