洛谷 P10162 [DTCPC 2024] 序列 题解

这里给出一个基于序列分治的 O(nlogn)O(n\log n) 做法。

以下记 vi=aimax{ai1,ai+1}v_i=a_i-\max \{a_{i-1},a_{i+1}\}li=aiai+1l_i=a_i-a_{i+1}ri=aiai1r_i=a_i-a_{i-1}

考虑分治,初始区间为 [1,n][1,n],每次取区间中点 mm,求出所有跨过区间中点的子区间的贡献(即所有满足 llm<rrl\leq l'\leq m<r'\leq r 的子区间 [l,r][l',r'] 的贡献),然后对中点的两侧分别递归。

设当前区间为 [l,r][l,r],区间中点为 mm。由于我们考虑的区间一定包含 mmm+1m+1,所以我们可以计算 pi=max{vm+1,vm+2,,vi1,ri}p_i=\max\{v_{m+1},v_{m+2},\cdots,v_{i-1},r_i\}si=max{li,vi+1,vi+2,,vm}s_i=\max\{l_i,v_{i+1},v_{i+2},\cdots,v_{m}\},那么一个子区间 [l,r][l',r'] 的贡献就是 max{sl,pr}\max\{s_{l'},p_{r'}\}。对所有子区间求和这个值的直接做法是排序后归并,但是这个做法再套上外面的分治会达到 O(nlog2n)O(n\log^2n) 的复杂度。

为了去掉排序,我们可以考虑通过在分治过程中归并的方式维护出有序的 pp 数组和 ss 数组。具体来说,我们对于区间 [l,r][l,r],维护出所有的 pi=max{vl,vl+1,,vi1,ri}p_i'=\max\{v_{l},v_{l+1},\cdots,v_{i-1},r_i\}si=max{li,vi+1,vi+2,,vr}s_i'=\max\{l_i,v_{i+1},v_{i+2},\cdots,v_{r}\}。在对 [l,r][l,r] 的中点两侧分别递归之后,[l,m][l,m] 一侧的所有 pip_i' 均不变,可以直接继承;[m+1,r][m+1,r] 一侧的所有 pip_i' 继承上来的过程中需要对 vl,vl+1,,vmv_{l},v_{l+1},\cdots,v_m 取 max,而一个有序序列对一个值取 max 之后仍然是有序的,所以我们可以分别得到有序的 pl,,mp'_{l,\cdots,m}pm+1,,rp'_{m+1,\cdots,r}。再将它们归并即可得到当前区间的 pp' 数组排序后的结果。ss' 的维护是对称的。求答案的时候,只需要分别取出左侧的有序的 ss' 数组和右侧的有序的 pp' 数组进行归并即可。

这样对一个区间分治就只需要进行若干次归并,复杂度就是线性的。再套上分治就达到了 O(nlogn)O(n\log n) 的目标复杂度,可以通过本题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
char c = getchar();
int x = 0, f = 1;
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return x * f;
}

typedef unsigned int ui;
const int N = 1000005;
int n, a[N], pmx[N], smx[N], tmp[N];
ui ans = 0;

inline void Read() {
n = qread();
for (int i = 1;i <= n;i++) a[i] = qread();
}

inline void DnC(int l, int r) {
if (l == r) {
pmx[l] = a[l] - a[l - 1];
smx[l] = a[l] - a[l + 1];
return;
}
int mid = l + r >> 1;
DnC(l, mid); DnC(mid + 1, r);
int i = l, j = mid + 1;
while (i <= mid && j <= r) {
if (smx[i] <= pmx[j]) {
ans += (ui)smx[i] * (ui)(j - mid - 1);
i++;
} else {
ans += (ui)pmx[j] * (ui)(i - l);
j++;
}
}
while (i <= mid) {
ans += (ui)smx[i] * (ui)(j - mid - 1);
i++;
}
while (j <= r) {
ans += (ui)pmx[j] * (ui)(i - l);
j++;
}
int mxv = -0x3f3f3f3f;
for (int i = l;i <= mid;i++) mxv = max(mxv, a[i] - max(a[i + 1], a[i - 1]));
for (int i = mid + 1;i <= r;i++) pmx[i] = max(pmx[i], mxv);
mxv = -0x3f3f3f3f;
for (int i = mid + 1;i <= r;i++) mxv = max(mxv, a[i] - max(a[i + 1], a[i - 1]));
for (int i = l;i <= mid;i++) smx[i] = max(smx[i], mxv);
merge(pmx + l, pmx + mid + 1, pmx + mid + 1, pmx + r + 1, tmp + l);
for (int i = l;i <= r;i++) pmx[i] = tmp[i];
merge(smx + l, smx + mid + 1, smx + mid + 1, smx + r + 1, tmp + l);
for (int i = l;i <= r;i++) smx[i] = tmp[i];
}

int main() {
Read();
ans = 0;
DnC(1, n);
cout << ans << endl;
return 0;
}