线段树合并学习笔记

线段树合并学习笔记

前置芝士

动态开点线段树

一般我们写线段树都是在刚开始调用一个 build 函数建立出所有节点。但实际上有很多节点并没有保存信息,这些节点是多余的。所以就有的动态开点线段树,只有在用到时才建立这个节点,这种思想节省了很多空间,节点与节点之间的灵活性也更高,为线段树合并和主席树等算法提供了基础。我们简单了解一下动态开点线段树的代码实现。

线段树节点不同于普通线段树,需要存储左右孩子的位置。

1
2
3
4
5
6
struct Segmemt_tree {
int l, r, ls, rs, ...; // ls, rs 存储节点左右儿子的 id
#define lid tr[id].ls
#define rid tr[id].rs;
}tr[N * 4];
int cnt = 0;

动态开点线段树的修改操作,基本与普通线段树相同:

1
2
3
4
5
6
7
8
9
10
11
12
void Add(int &id, int l, int r, ...) { // 注意 id 是引用
if (!id) id = ++ cnt; // 若没有这个节点则新建一个
tr[id].l = l, tr[id].r = r;
if (l == r) {
...
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (...) Add(lid, l, mid, ...);
if (...) Add(rid, mid + 1, r, ...);
pushup(id);
}

查询:

1
2
3
4
5
int Ask(int id, int l, int r, ...) {
if (!id) return 0;
// 因为是询问,没有这个点就可以直接返回了
// 后面与普通线段树基本相同
}

权值线段树

我们的线段树合并,99% 的情况合并的是权值线段树。 —— 某同机房巨佬

权值线段树类似于权值树状数组,是在值域上建立的的线段树,维护值的信息,如出现次数等。

其实和正常线段树没啥区别,就是因为维护的东西比较特殊就被拎出来了 Qwq。

权值线段树应该挺简单的,我们提一下就过吧。

线段树合并

顾名思义,合并两个线段树。一般用来维护与深度相关的信息、离线处理一些问题。

假设现在我们有两棵线段树 $x$ 和 $y$ ,我们考虑怎么合并,当然两棵线段树维护的区间显然要相等。

  • 如果 $x$ 或 $y$ 是空的(因为是动态开点),直接返回不空的那个;
  • 如果到达边界,即 $x$ 和 $y$ 维护的区间只有一个元素,直接合并维护的信息;
  • 否则递归合并 $x$ 、 $y$ 的左子树和右子树,用两个子树向上更新出合并后的节点。

考虑这样合并的复杂度是多少,显然是 $\Theta \left( 节点数\right)$,一般就是 $\Theta(n \log n)$ 了。

代码实现

在动态开点线段树的基础上多了一个 $\tt{Merge}$ (合并)函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void Merge(int &x, int y, int l, int r) {
// 你可以把两个线段树合并到一个新点上,这里我是把 y 合并到 x 上
if (!y) return;
if (!x) return x = y, void();
// 有一个节点为空的情况
if (l == r) { // 到达边界,直接合并信息
...
return;
}
int mid = tr[x].l + tr[x].r >> 1;
Merge(tr[x].ls, tr[y].ls, l, mid);
Merge(tr[x].rs, tr[y].rs, mid + 1, r);
pushup(x); // 合并左右子树并更新
}

一些题目

[USACO17JAN]Promotion Counting P

link

给定一棵树 $n$ 个节点、以 $1$ 为根的树,树上每个点有权值 $p_i$。

求以每个点为根的子树内权值大于它的节点个数。

$1 \le n \le 10^5$,$1 \le p_i \le 10^9$

线段树合并裸题,码个板子练练手。离散化后对于每个点建立权值线段树,维护每个值出现的次数,$\tt{dfs}$ 将所有子节点线段树合并到父节点上,对于 $p_i$ 在线段树上查询 $\left[p_i + 1 ,n\right]$ 区间和即可。

刚开始每个节点的线段树中只插入一个值,因此线段树上只有 $\log n$ 个节点;合并操作中不会新建节点,因此总空间复杂度就是 $\Theta(n \log n)$,每个节点会进行 $\Theta(\log n)$ 的合并,所以时间复杂度也是 $\Theta(n \log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
using namespace std;

inline int read() {
int x = 0, f = 0; char c = 0;
while (!isdigit(c)) f |= c == '-', c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return f ? -x : x;
}

#define N 100010
#define pb push_back

int n, a[N], rt[N], res[N];
vector<int> Hash, e[N];

#define lid tr[id].ls
#define rid tr[id].rs
struct Segment_tree {
int l, r, ls, rs, sum;
}tr[N * 20];
int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
if (!id) id = New();
tr[id].l = l, tr[id].r = r;
if (l == r) {
tr[id].sum ++;
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (x <= mid) Add(lid, l, mid, x);
if (x > mid) Add(rid, mid + 1, r, x);
pushup(id);
}
int Ask(int id, int l, int r) {
if (!id) return 0;
if (l <= tr[id].l && tr[id].r <= r) {
return tr[id].sum;
}
int mid = tr[id].l + tr[id].r >> 1, val = 0;
if (l <= mid) val += Ask(lid, l, r);
if (r > mid) val += Ask(rid, l, r);
return val;
}
void Merge(int &x, int y, int l, int r) {
if (!x) return x = y, void();
if (!y) return;
if (l == r) {
tr[x].sum += tr[y].sum;
return;
}
int mid = tr[x].l + tr[x].r >> 1;
Merge(tr[x].ls, tr[y].ls, l, mid);
Merge(tr[x].rs, tr[y].rs, mid + 1, r);
pushup(x);
}

void dfs(int x, int fa) {
Add(rt[x] = New(), 1, n, a[x]);
for (auto y : e[x]) if (y != fa) {
dfs(y, x);
Merge(rt[x], rt[y], 1, n);
}
res[x] = Ask(rt[x], a[x] + 1, n);
}

int main() {
n = read();

for (int i = 1; i <= n; i ++) {
a[i] = read(), Hash.pb(a[i]);
}
sort(Hash.begin(), Hash.end());
unique(Hash.begin(), Hash.end());
for (int i = 1; i <= n; i ++) {
a[i] = lower_bound(Hash.begin(), Hash.end(), a[i]) - Hash.begin() + 1;
}

for (int i = 2; i <= n; i ++) {
int x = read();
e[i].pb(x), e[x].pb(i);
}

dfs(1, 0);

for (int i = 1; i <= n; i ++) {
printf("%d\n", res[i]);
}
return 0;
}

CF600E Lomsat gelral

link

给定一棵 $n$ 个节点,$1$ 号节点为根节点的树,每个节点有权值 $c_i$。

求出以所有节点为根的子树内出现次数最多的权值的和(因为可能有多个最大权值)。

$1 \le n \le 10^5$,$1 \le c_i \le n$。

线段树合并裸题 ++。

对于每个节点开一棵权值线段树(良心出题人,不用离散化),维护区间最大最大值和所有最大值的和, $\tt{dfs}$ 合并线段树即可。

注意答案可以达到 $10 ^ 5 \times 10 ^ 5 = 10 ^ {10}$,别忘了 longlong

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;

inline int read() {
int x = 0, f = 0; char c = 0;
while (!isdigit(c)) f |= c == '-', c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return f ? -x : x;
}

#define N 100010
#define pb push_back

int n, a[N], rt[N];
long long res[N];
vector<int> e[N];

struct Segment_tree {
int l, r, ls, rs;
long long Mx, sum;
#define lid tr[id].ls
#define rid tr[id].rs
}tr[N * 20]; int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
tr[id].Mx = max(tr[lid].Mx, tr[rid].Mx);
if (tr[lid].Mx > tr[rid].Mx) tr[id].sum = tr[lid].sum;
if (tr[lid].Mx < tr[rid].Mx) tr[id].sum = tr[rid].sum;
if (tr[lid].Mx == tr[rid].Mx) tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
if (!id) id = New();
tr[id].l = l, tr[id].r = r;
if (l == r) {
tr[id].Mx = 1, tr[id].sum = x;
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (x <= mid) Add(lid, l, mid, x);
if (x > mid) Add(rid, mid + 1, r, x);
pushup(id);
}
void Merge(int &x, int y, int l, int r) {
if (!y) return;
if (!x) return x = y, void();
if (l == r) {
tr[x].Mx += tr[y].Mx;
return;
}
int mid = l + r >> 1;
Merge(tr[x].ls, tr[y].ls, l, mid);
Merge(tr[x].rs, tr[y].rs, mid + 1, r);
pushup(x);
}

void dfs(int x, int fa) {
Add(rt[x] = New(), 1, n, a[x]);
for (auto y : e[x]) if (y != fa) {
dfs(y, x), Merge(rt[x], rt[y], 1, n);
}
res[x] = tr[rt[x]].sum;
}

int main() {
n = read();
for (int i = 1; i <= n; i ++) {
a[i] = read();
}
for (int i = 1; i < n; i ++) {
int x = read(), y = read();
e[x].pb(y), e[y].pb(x);
}

dfs(1, 0);

for (int i = 1; i <= n; i ++) {
printf("%lld ", res[i]);
} puts("");
return 0;
}

[POI2011]ROT-Tree Rotations

link

给定一棵 $n$ 个叶子节点二叉树,其中每个点要么没有子树,要么有两棵子树,叶子节点有权值 $x$。

你可以任意交换一个节点的左右子树,使得先序遍历得到的叶子节点权值的排列逆序对最少。

$1 \le n \le 2 \times 10 ^ 5$,$0 \le x \le n$。

考虑到先序遍历是先父亲再左右子树,因此子树交换并不影响外部的逆序对。同样使用权值线段树加线段树合并,$\tt{dfs}$ 考虑一个节点,我们计算出交换与不交换的两种贡献,取最小的即可,我们可以在合并线段树时计算两个节点之间的逆序对个数。

假设现在要合并 $x$ 和 $y$,不交换就是 tr[tr[x].rs].sum * tr[tr[y].ls].sum,交换就是 tr[tr[x].ls].sum * tr[tr[y].rs].sum,接着在合并左右子树时处理剩下的逆序对。

这样说可能比较难理解,举个栗子:

假设现在左右子树 $x$、$y$ 的排列分别为 1 2 3 31 1 2 3

不交换有 $8$ 个逆序对:2,1 2,1 3,1 3,1 3,2 3,1 3,1 3,2

交换有 $3$ 个:2,1 3,1 3,2

建出两棵线段树:

节点内的红字这个节点的 $\tt{sum}$ 值,即这个节点维护区间的值的数量。

刚开始两棵子树都在 $\tt{[1,3]}$;

  • 用 $x$ 的 $\tt{[3,3]}$ 区间乘以 $y$ 的 $\tt{[1,2]}$,$2 \times 3 = 6$,代表不交换的逆序对 3,1 3,1 3,2 3,1 3,1 3,2
  • 用 $x$ 的 $\tt{[1,2]}$ 区间乘以 $y$ 的 $\tt{[3,3]}$,$2 \times 1 = 2$,代表交换的逆序对 3,1 3,2

但是这样还有一些逆序对没找全啊?所以要在接下来 $x$、$y$ 的左右子树合并时继续计算。

合并 $x$ 、$y$ 的左子树 $\tt{[1,2]}$:

  • 用 $x$ 的 $\tt{[2,2]}$ 区间乘以 $y$ 的 $\tt{[1,1]}$,$1 \times 2 = 2$,代表不交换的逆序对 2,1 2,1
  • 用 $x$ 的 $\tt{[1,1]}$ 区间乘以 $y$ 的 $\tt{[2,2]}$,$1 \times 1 = 1$,代表交换的逆序对 2,1

合并 $x$ 、$y$ 的左子树 $\tt{[3,3]}$,一个点显然没有逆序对,不用计算。

这样不交换有 $6+2=8$ 个逆序对,交换有 $2+1=3$ 个逆序对。这就对了嘛。

这道题同样需要 long long,另外因为是 $n$ 个叶子节点不是 $n$ 个节点,数组要开大一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
using namespace std;

inline int read() {
int x = 0, f = 0; char c = 0;
while (!isdigit(c)) f |= c == '-', c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return f ? -x : x;
}

#define N 400010
typedef int Ary[N];

int n, tot = 0;
long long res, s1, s2;
Ary a, L, R, rt;

int In(int k = 0) {
a[k = ++ tot] = read();
if (a[k] == 0) {
L[k] = In(), R[k] = In();
}
return k;
}

#define lid tr[id].ls
#define rid tr[id].rs
struct Segmemt_tree {
int l, r, ls, rs, sum;
}tr[N * 20];
int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
if (!id) id = New();
tr[id].l = l, tr[id].r = r;
if (l == r) {
tr[id].sum ++;
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (x <= mid) Add(lid, l, mid, x);
if (x > mid) Add(rid, mid + 1, r, x);
pushup(id);
}
void Merge(int &x, int y, int l, int r) {
if (!x) return x = y, void();
if (!y) return;
if (l == r) {
tr[x].sum += tr[y].sum;
return;
}
int mid = tr[x].l + tr[x].r >> 1;
s1 += 1ll * tr[tr[x].ls].sum * tr[tr[y].rs].sum;
s2 += 1ll * tr[tr[x].rs].sum * tr[tr[y].ls].sum;
Merge(tr[x].ls, tr[y].ls, l, mid);
Merge(tr[x].rs, tr[y].rs, mid + 1, r);
pushup(x);
}

void dfs(int x) {
rt[x] = New();
if (a[x] > 0) Add(rt[x], 1, n, a[x]);
if (L[x] > 0) {
dfs(L[x]), s1 = s2 = 0, Merge(rt[x], rt[L[x]], 1, n);
}
if (R[x] > 0) {
dfs(R[x]), s1 = s2 = 0, Merge(rt[x], rt[R[x]], 1, n);
}
res += min(s1, s2);
}

int main() {
n = read(), In(), dfs(1);
printf("%lld\n", res);
return 0;
}

CF208E Blood Cousins

link

给定一片 $n$ 个点的森林,森林中的树都有根。

$m$ 次询问,每次询问一个点 $x$ 与多少个点有相同的 $K$ 级祖先。

$1 \le n \le 10^5$,$1 \le m \le 10^5$

离线做法,首先用倍增对于每个询问求出 $x$ 的 $K$ 级祖先,用 $\tt{vector}$ 或链表把询问挂在祖先上,问题转换成求一个点有多少个 $K$ 级儿子的问题。

对每个节点,以深度为下标建立权值线段树,进行 $\tt{dfs}$。若 $\tt{dfs}$ 到一个节点 $x$,在线段树插入当前点的深度并与子节点的线段树合并。然后遍历所有询问,对于询问这个点有多少个 $K$ 级儿子,在线段树中查询深度为 $dep_x + K$ 的节点有多少个即可。别忘了答案要减一。

时间复杂度 $\Theta(n \log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <bits/stdc++.h>
using namespace std;

inline int read() {
int x = 0, f = 0; char c = 0;
while (!isdigit(c)) f |= c == '-', c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return f ? -x : x;
}

#define N 100010
#define pb push_back
#define PII pair<int, int>
#define mp make_pair
#define fi first
#define se second

int n, m, res[N], f[N][22], rt[N], dep[N];
vector<int> e[N], root;
vector<PII> q[N];

struct Segment_tree {
int l, r, ls, rs, sum;
#define lid tr[id].ls
#define rid tr[id].rs
}tr[N * 20];
int cnt = 0;
void Add(int &id, int l, int r, int x) {
if (!id) id = ++ cnt;
tr[id].l = l, tr[id].r = r;
if (l == r) {
tr[id].sum ++;
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (x <= mid) Add(lid, l, mid, x);
if (x > mid) Add(rid, mid + 1, r, x);
}
int Ask(int id, int l, int r, int x) {
if (!id) return 0;
if (l == r) return tr[id].sum;
int mid = tr[id].l + tr[id].r >> 1;
if (x <= mid) return Ask(lid, l, mid, x);
if (x > mid) return Ask(rid, mid + 1, r, x);
}
void Merge(int &x, int y, int l, int r) {
if (!x || !y) return x += y, void();
if (l == r) {
tr[x].sum += tr[y].sum;
return;
}
int mid = tr[x].l + tr[x].r >> 1;
Merge(tr[x].ls, tr[y].ls, l, mid);
Merge(tr[x].rs, tr[y].rs, mid + 1, r);
}

void dfs2(int x, int fa) {
Add(rt[x] = ++ cnt, 1, n, dep[x]);
for (auto y : e[x]) if (y != fa) {
dfs2(y, x), Merge(rt[x], rt[y], 1, n);
}
for (auto i : q[x]) {
res[i.fi] = Ask(rt[x], 1, n, dep[x] + i.se) - 1;
}
}

void dfs1(int x, int fa) {
dep[x] = dep[fa] + 1;
f[x][0] = fa;
for (int i = 1; i <= 20; i ++) {
f[x][i] = f[f[x][i - 1]][i - 1];
}
for (auto y : e[x]) {
if (y != fa) dfs1(y, x);
}
}

int main() {
n = read();
for (int i = 1; i <= n; i ++) {
int x = read();
if (x) e[x].pb(i), e[i].pb(x);
else root.pb(i);
}

for (auto x : root) dfs1(x, 0);

m = read();
for (int i = 1; i <= m; i ++) {
int x = read(), y = read(), fa = x;
for (int i = 0, j = y; i <= 20; i ++, j >>= 1) {
if (j & 1) fa = f[fa][i];
}
q[fa].pb(mp(i, y));
}

for (auto x : root) dfs2(x, 0);

for (int i = 1; i <= m; i ++) {
printf("%d ", res[i]);
} puts("");
return 0;
}

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!