JZOJ 5231.序列问题

真是个码农题……

由于是乘,所以没法用单调栈来统计答案贡献。

考虑分治套路,于是问题变成了求区间左端点 [l,mid]\in [l,mid],区间右端点 [mid+1,r]\in [mid + 1,r] 的答案。
由于最值的单调性,我们可以多维护两个指针使得每层递归只需 O(n)O(n)

具体地说,我们按 midlmid \rightarrow l 顺序枚举 ii,并同时维护满足 maxj=mid+1j1ajmaxj=imidaj\max\limits_{j = mid + 1}^{j_1}{a_j} \le \max\limits_{j = i}^{mid}{a_j} 的最大的 j1j_1 和满足 minj=mid+1j2ajminj=imidaj\min\limits_{j = mid + 1}^{j_2}{a_j} \ge \min\limits_{j = i}^{mid}{a_j} 的最大的 j2j_2
于是我们就可以把计算 [mid+1,r][mid + 1,r]ii 的贡献分成三个部分,这里我们假设 j1<j2j_1 < j_2,反过来类似:

  1. [mid+1,j1][mid + 1,j_1],这一段的最值与 [i,mid][i,mid] 相同。
  2. [j1+1,j2][j_1 + 1,j_2],这一段仅最小值与 [i,mid][i,mid] 相同。
  3. [j2+1,r][j_2 + 1,r] 这一段的最值与 [i,mid][i,mid] 不同。

于是我们可以在递归进入时维护五个量:

  1. maxi=maxj=mid+1iajmax_i = \max\limits_{j = mid + 1}^i{a_j}
  2. mini=minj=mid+1iajmin_i = \min\limits_{j = mid + 1}^i{a_j}
  3. maxsumi=j=mid+1imaxjmaxsum_i = \sum\limits_{j = mid + 1}^i{max_j}
  4. minsumi=j=mid+1iminjminsum_i = \sum\limits_{j = mid + 1}^i{min_j}
  5. sumi=j=mid+1imaximinisum_i = \sum\limits_{j = mid + 1}^i{max_i \cdot min_i}

通过这五个量,三个部分的贡献就容易得出了。

此外,注意取模时,应该把较大的数先取模再乘。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 5e5;
const long long mod = 1e9 + 7;
int n;
long long a[N + 5],sum[N + 5],mx[N + 5],mn[N + 5],mxsum[N + 5],mnsum[N + 5];
long long ans;
void solve(int l,int r)
{
if(l == r)
{
ans = (ans + a[l] * a[l] % mod) % mod;
return ;
}
int mid = l + r >> 1;
solve(l,mid),solve(mid + 1,r);
mxsum[mid] = mnsum[mid] = sum[mid] = 0;
mx[mid] = -0x3f3f3f3f3f3f3f3fLL,mn[mid] = 0x3f3f3f3f3f3f3f3fLL;
for(register int i = mid + 1;i <= r;++i)
mx[i] = max(mx[i - 1],a[i]),
mn[i] = min(mn[i - 1],a[i]),
mxsum[i] = mxsum[i - 1] + mx[i],
mnsum[i] = mnsum[i - 1] + mn[i],
sum[i] = (sum[i - 1] + mx[i] * mn[i] % mod) % mod;
long long curmax = -0x3f3f3f3f3f3f3f3fLL,curmin = 0x3f3f3f3f3f3f3f3fLL;
for(register int i = mid,j1 = mid,j2 = mid;i >= l;--i)
{
curmax = max(curmax,a[i]),curmin = min(curmin,a[i]);
while(j1 < r && mx[j1 + 1] <= curmax)
++j1;
while(j2 < r && mn[j2 + 1] >= curmin)
++j2;
if(j1 <= j2)
ans = (ans +
curmax % mod * curmin % mod * (j1 - mid) % mod +
(mxsum[j2] - mxsum[j1]) % mod * curmin % mod +
(sum[r] - sum[j2]) % mod
) % mod;
else
ans = (ans +
curmax % mod * curmin % mod * (j2 - mid) % mod +
(mnsum[j1] - mnsum[j2]) % mod * curmax % mod +
(sum[r] - sum[j1]) % mod
) % mod;
}
}
int main()
{
freopen("seq.in","r",stdin);
freopen("seq.out","w",stdout);
scanf("%d",&n);
for(register int i = 1;i <= n;++i)
scanf("%lld",a + i);
solve(1,n);
printf("%lld\n",(ans + mod) % mod);
}