线段树合并学习笔记

线段树合并学习笔记

前置芝士

动态开点线段树

一般我们写线段树都是在刚开始调用一个 build 函数建立出所有节点。但实际上有很多节点并没有保存信息,这些节点是多余的。所以就有的动态开点线段树,只有在用到时才建立这个节点,这种思想节省了很多空间,节点与节点之间的灵活性也更高,为线段树合并和主席树等算法提供了基础。我们简单了解一下动态开点线段树的代码实现。

线段树节点不同于普通线段树,需要存储左右孩子的位置。

struct Segmemt_tree {
	int l, r, ls, rs, ...; // ls, rs 存储节点左右儿子的 id
	#define lid tr[id].ls
	#define rid tr[id].rs;
}tr[N * 4];
int cnt = 0;

动态开点线段树的修改操作,基本与普通线段树相同:

void Add(int &id, int l, int r, ...) { // 注意 id 是引用
	if (!id) id = ++ cnt; // 若没有这个节点则新建一个
	tr[id].l = l, tr[id].r = r;
	if (l == r) {
		...
		return;
	}
	int mid = tr[id].l + tr[id].r >> 1;
	if (...) Add(lid, l, mid, ...);
	if (...) Add(rid, mid + 1, r, ...);
	pushup(id);
}

查询:

int Ask(int id, int l, int r, ...) {
	if (!id) return 0;
	// 因为是询问,没有这个点就可以直接返回了
	// 后面与普通线段树基本相同
}

权值线段树

我们的线段树合并,99% 的情况合并的是权值线段树。 —— 某同机房巨佬

权值线段树类似于权值树状数组,是在值域上建立的的线段树,维护值的信息,如出现次数等。

其实和正常线段树没啥区别,就是因为维护的东西比较特殊就被拎出来了 Qwq。

权值线段树应该挺简单的,我们提一下就过吧。

线段树合并

顾名思义,合并两个线段树。一般用来维护与深度相关的信息、离线处理一些问题。

假设现在我们有两棵线段树 xxyy ,我们考虑怎么合并,当然两棵线段树维护的区间显然要相等。

  • 如果 xxyy 是空的(因为是动态开点),直接返回不空的那个;
  • 如果到达边界,即 xxyy 维护的区间只有一个元素,直接合并维护的信息;
  • 否则递归合并 xxyy 的左子树和右子树,用两个子树向上更新出合并后的节点。

考虑这样合并的复杂度是多少,显然是 Θ(节点数)\Theta \left( 节点数\right),一般就是 Θ(nlogn)\Theta(n \log n) 了。

代码实现

在动态开点线段树的基础上多了一个 Merge\tt{Merge} (合并)函数:

void Merge(int &x, int y, int l, int r) {
	// 你可以把两个线段树合并到一个新点上,这里我是把 y 合并到 x 上
	if (!y) return;
	if (!x) return x = y, void();
	// 有一个节点为空的情况
	if (l == r) { // 到达边界,直接合并信息
		...
		return;
	}
	int mid = tr[x].l + tr[x].r >> 1;
	Merge(tr[x].ls, tr[y].ls, l, mid);
	Merge(tr[x].rs, tr[y].rs, mid + 1, r);
	pushup(x); // 合并左右子树并更新
}

一些题目

[USACO17JAN]Promotion Counting P

link

给定一棵树 nn 个节点、以 11 为根的树,树上每个点有权值 pip_i

求以每个点为根的子树内权值大于它的节点个数。

1n1051 \le n \le 10^51pi1091 \le p_i \le 10^9

线段树合并裸题,码个板子练练手。离散化后对于每个点建立权值线段树,维护每个值出现的次数,dfs\tt{dfs} 将所有子节点线段树合并到父节点上,对于 pip_i 在线段树上查询 [pi+1,n]\left[p_i + 1 ,n\right] 区间和即可。

刚开始每个节点的线段树中只插入一个值,因此线段树上只有 logn\log n 个节点;合并操作中不会新建节点,因此总空间复杂度就是 Θ(nlogn)\Theta(n \log n),每个节点会进行 Θ(logn)\Theta(\log n) 的合并,所以时间复杂度也是 Θ(nlogn)\Theta(n \log n)

#include <bits/stdc++.h>
using namespace std;

inline int read() {
	int x = 0, f = 0; char c = 0;
	while (!isdigit(c)) f |= c == '-', c = getchar();
	while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
	return f ? -x : x;
}

#define N 100010
#define pb push_back

int n, a[N], rt[N], res[N];
vector<int> Hash, e[N];

#define lid tr[id].ls
#define rid tr[id].rs
struct Segment_tree {
	int l, r, ls, rs, sum;
}tr[N * 20];
int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
	tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
	if (!id) id = New();
	tr[id].l = l, tr[id].r = r;
	if (l == r) {
		tr[id].sum ++;
		return;
	}
	int mid = tr[id].l + tr[id].r >> 1;
	if (x <= mid) Add(lid, l, mid, x);
	if (x > mid) Add(rid, mid + 1, r, x);
	pushup(id);
}
int Ask(int id, int l, int r) {
	if (!id) return 0;
	if (l <= tr[id].l && tr[id].r <= r) {
		return tr[id].sum;
	}
	int mid = tr[id].l + tr[id].r >> 1, val = 0;
	if (l <= mid) val += Ask(lid, l, r);
	if (r > mid) val += Ask(rid, l, r);
	return val;
}
void Merge(int &x, int y, int l, int r) {
	if (!x) return x = y, void();
	if (!y) return;
	if (l == r) {
		tr[x].sum += tr[y].sum;
		return;
	}
	int mid = tr[x].l + tr[x].r >> 1;
	Merge(tr[x].ls, tr[y].ls, l, mid);
	Merge(tr[x].rs, tr[y].rs, mid + 1, r);
	pushup(x);
}

void dfs(int x, int fa) {
	Add(rt[x] = New(), 1, n, a[x]);
	for (auto y : e[x]) if (y != fa) {
		dfs(y, x);
		Merge(rt[x], rt[y], 1, n);
	}
	res[x] = Ask(rt[x], a[x] + 1, n);
}

int main() {
	n = read();

	for (int i = 1; i <= n; i ++) {
		a[i] = read(), Hash.pb(a[i]);
	}
	sort(Hash.begin(), Hash.end());
	unique(Hash.begin(), Hash.end());
	for (int i = 1; i <= n; i ++) {
		a[i] = lower_bound(Hash.begin(), Hash.end(), a[i]) - Hash.begin() + 1;
	}

	for (int i = 2; i <= n; i ++) {
		int x = read();
		e[i].pb(x), e[x].pb(i);
	}

	dfs(1, 0);

	for (int i = 1; i <= n; i ++) {
		printf("%d\n", res[i]);
	}
	return 0;
}

CF600E Lomsat gelral

link

给定一棵 nn 个节点,11 号节点为根节点的树,每个节点有权值 cic_i

求出以所有节点为根的子树内出现次数最多的权值的和(因为可能有多个最大权值)。

1n1051 \le n \le 10^51cin1 \le c_i \le n

线段树合并裸题 ++。

对于每个节点开一棵权值线段树(良心出题人,不用离散化),维护区间最大最大值和所有最大值的和, dfs\tt{dfs} 合并线段树即可。

注意答案可以达到 105×105=101010 ^ 5 \times 10 ^ 5 = 10 ^ {10},别忘了 longlong

#include <bits/stdc++.h>
using namespace std;

inline int read() {
    int x = 0, f = 0; char c = 0;
    while (!isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
    return f ? -x : x;
}

#define N 100010
#define pb push_back

int n, a[N], rt[N];
long long res[N];
vector<int> e[N];

struct Segment_tree {
    int l, r, ls, rs;
    long long Mx, sum;
    #define lid tr[id].ls
    #define rid tr[id].rs
}tr[N * 20]; int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
    tr[id].Mx = max(tr[lid].Mx, tr[rid].Mx);
    if (tr[lid].Mx > tr[rid].Mx) tr[id].sum = tr[lid].sum;
    if (tr[lid].Mx < tr[rid].Mx) tr[id].sum = tr[rid].sum;
    if (tr[lid].Mx == tr[rid].Mx) tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
    if (!id) id = New();
    tr[id].l = l, tr[id].r = r;
    if (l == r) {
        tr[id].Mx = 1, tr[id].sum = x;
        return;
    }
    int mid = tr[id].l + tr[id].r >> 1;
    if (x <= mid) Add(lid, l, mid, x);
    if (x > mid) Add(rid, mid + 1, r, x);
    pushup(id);
}
void Merge(int &x, int y, int l, int r) {
    if (!y) return;
    if (!x) return x = y, void();
    if (l == r) {
        tr[x].Mx += tr[y].Mx;
        return;
    }
    int mid = l + r >> 1;
    Merge(tr[x].ls, tr[y].ls, l, mid);
    Merge(tr[x].rs, tr[y].rs, mid + 1, r);
    pushup(x);
}

void dfs(int x, int fa) {
    Add(rt[x] = New(), 1, n, a[x]);
    for (auto y : e[x]) if (y != fa) {
        dfs(y, x), Merge(rt[x], rt[y], 1, n);
    }
    res[x] = tr[rt[x]].sum;
}

int main() {
    n = read();
    for (int i = 1; i <= n; i ++) {
        a[i] = read();
    }
    for (int i = 1; i < n; i ++) {
        int x = read(), y = read();
        e[x].pb(y), e[y].pb(x);
    }

    dfs(1, 0);

    for (int i = 1; i <= n; i ++) {
        printf("%lld ", res[i]);
    } puts("");
    return 0;
}

[POI2011]ROT-Tree Rotations

link

给定一棵 nn叶子节点二叉树,其中每个点要么没有子树,要么有两棵子树,叶子节点有权值 xx

你可以任意交换一个节点的左右子树,使得先序遍历得到的叶子节点权值的排列逆序对最少。

1n2×1051 \le n \le 2 \times 10 ^ 50xn0 \le x \le n

考虑到先序遍历是先父亲再左右子树,因此子树交换并不影响外部的逆序对。同样使用权值线段树加线段树合并,dfs\tt{dfs} 考虑一个节点,我们计算出交换与不交换的两种贡献,取最小的即可,我们可以在合并线段树时计算两个节点之间的逆序对个数。

假设现在要合并 xxyy,不交换就是 tr[tr[x].rs].sum * tr[tr[y].ls].sum,交换就是 tr[tr[x].ls].sum * tr[tr[y].rs].sum,接着在合并左右子树时处理剩下的逆序对。

这样说可能比较难理解,举个栗子:

假设现在左右子树 xxyy 的排列分别为 1 2 3 31 1 2 3

不交换有 88 个逆序对:2,1 2,1 3,1 3,1 3,2 3,1 3,1 3,2

交换有 33 个:2,1 3,1 3,2

建出两棵线段树:

节点内的红字这个节点的 sum\tt{sum} 值,即这个节点维护区间的值的数量。

刚开始两棵子树都在 [1,3]\tt{[1,3]}

  • xx[3,3]\tt{[3,3]} 区间乘以 yy[1,2]\tt{[1,2]}2×3=62 \times 3 = 6,代表不交换的逆序对 3,1 3,1 3,2 3,1 3,1 3,2
  • xx[1,2]\tt{[1,2]} 区间乘以 yy[3,3]\tt{[3,3]}2×1=22 \times 1 = 2,代表交换的逆序对 3,1 3,2

但是这样还有一些逆序对没找全啊?所以要在接下来 xxyy 的左右子树合并时继续计算。

合并 xxyy 的左子树 [1,2]\tt{[1,2]}

  • xx[2,2]\tt{[2,2]} 区间乘以 yy[1,1]\tt{[1,1]}1×2=21 \times 2 = 2,代表不交换的逆序对 2,1 2,1
  • xx[1,1]\tt{[1,1]} 区间乘以 yy[2,2]\tt{[2,2]}1×1=11 \times 1 = 1,代表交换的逆序对 2,1

合并 xxyy 的左子树 [3,3]\tt{[3,3]},一个点显然没有逆序对,不用计算。

这样不交换有 6+2=86+2=8 个逆序对,交换有 2+1=32+1=3 个逆序对。这就对了嘛。

这道题同样需要 long long,另外因为是 nn 个叶子节点不是 nn 个节点,数组要开大一点。

#include <bits/stdc++.h>
using namespace std;

inline int read() {
	int x = 0, f = 0; char c = 0;
	while (!isdigit(c)) f |= c == '-', c = getchar();
	while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
	return f ? -x : x;
}

#define N 400010
typedef int Ary[N];

int n, tot = 0;
long long res, s1, s2;
Ary a, L, R, rt;

int In(int k = 0) {
	a[k = ++ tot] = read();
	if (a[k] == 0) {
		L[k] = In(), R[k] = In();
	}
	return k;
}

#define lid tr[id].ls
#define rid tr[id].rs
struct Segmemt_tree {
	int l, r, ls, rs, sum;
}tr[N * 20];
int cnt = 0;
int New() {return ++ cnt;}
void pushup(int id) {
	tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void Add(int &id, int l, int r, int x) {
	if (!id) id = New();
	tr[id].l = l, tr[id].r = r;
	if (l == r) {
		tr[id].sum ++;
		return;
	}
	int mid = tr[id].l + tr[id].r >> 1;
	if (x <= mid) Add(lid, l, mid, x);
	if (x > mid) Add(rid, mid + 1, r, x);
	pushup(id);
}
void Merge(int &x, int y, int l, int r) {
	if (!x) return x = y, void();
	if (!y) return;
	if (l == r) {
		tr[x].sum += tr[y].sum;
		return;
	}
	int mid = tr[x].l + tr[x].r >> 1;
	s1 += 1ll * tr[tr[x].ls].sum * tr[tr[y].rs].sum;
	s2 += 1ll * tr[tr[x].rs].sum * tr[tr[y].ls].sum;
	Merge(tr[x].ls, tr[y].ls, l, mid);
	Merge(tr[x].rs, tr[y].rs, mid + 1, r);
	pushup(x);
}

void dfs(int x) {
	rt[x] = New();
	if (a[x] > 0) Add(rt[x], 1, n, a[x]);
	if (L[x] > 0) {
		dfs(L[x]), s1 = s2 = 0, Merge(rt[x], rt[L[x]], 1, n);
	}
	if (R[x] > 0) {
		dfs(R[x]), s1 = s2 = 0, Merge(rt[x], rt[R[x]], 1, n);
	}
	res += min(s1, s2);
}

int main() {
	n = read(), In(), dfs(1);
	printf("%lld\n", res);
	return 0;
}

CF208E Blood Cousins

link

给定一片 nn 个点的森林,森林中的树都有根。

mm 次询问,每次询问一个点 xx 与多少个点有相同的 KK 级祖先。

1n1051 \le n \le 10^51m1051 \le m \le 10^5

离线做法,首先用倍增对于每个询问求出 xxKK 级祖先,用 vector\tt{vector} 或链表把询问挂在祖先上,问题转换成求一个点有多少个 KK 级儿子的问题。

对每个节点,以深度为下标建立权值线段树,进行 dfs\tt{dfs}。若 dfs\tt{dfs} 到一个节点 xx,在线段树插入当前点的深度并与子节点的线段树合并。然后遍历所有询问,对于询问这个点有多少个 KK 级儿子,在线段树中查询深度为 depx+Kdep_x + K 的节点有多少个即可。别忘了答案要减一。

时间复杂度 Θ(nlogn)\Theta(n \log n)

#include <bits/stdc++.h>
using namespace std;

inline int read() {
	int x = 0, f = 0; char c = 0;
	while (!isdigit(c)) f |= c == '-', c = getchar();
	while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
	return f ? -x : x;
}

#define N 100010
#define pb push_back
#define PII pair<int, int>
#define mp make_pair
#define fi first
#define se second

int n, m, res[N], f[N][22], rt[N], dep[N];
vector<int> e[N], root;
vector<PII> q[N];

struct Segment_tree {
	int l, r, ls, rs, sum;
	#define lid tr[id].ls
	#define rid tr[id].rs
}tr[N * 20];
int cnt = 0;
void Add(int &id, int l, int r, int x) {
	if (!id) id = ++ cnt;
	tr[id].l = l, tr[id].r = r;
	if (l == r) {
		tr[id].sum ++;
		return;
	}
	int mid = tr[id].l + tr[id].r >> 1;
	if (x <= mid) Add(lid, l, mid, x);
	if (x > mid) Add(rid, mid + 1, r, x);
}
int Ask(int id, int l, int r, int x) {
	if (!id) return 0;
	if (l == r) return tr[id].sum;
	int mid = tr[id].l + tr[id].r >> 1;
	if (x <= mid) return Ask(lid, l, mid, x);
	if (x > mid) return Ask(rid, mid + 1, r, x);
}
void Merge(int &x, int y, int l, int r) {
	if (!x || !y) return x += y, void();
	if (l == r) {
		tr[x].sum += tr[y].sum;
		return;
	}
	int mid = tr[x].l + tr[x].r >> 1;
	Merge(tr[x].ls, tr[y].ls, l, mid);
	Merge(tr[x].rs, tr[y].rs, mid + 1, r);
}

void dfs2(int x, int fa) {
	Add(rt[x] = ++ cnt, 1, n, dep[x]);
	for (auto y : e[x]) if (y != fa) {
		dfs2(y, x), Merge(rt[x], rt[y], 1, n);
	}
	for (auto i : q[x]) {
		res[i.fi] = Ask(rt[x], 1, n, dep[x] + i.se) - 1;
	}
}

void dfs1(int x, int fa) {
	dep[x] = dep[fa] + 1;
	f[x][0] = fa;
	for (int i = 1; i <= 20; i ++) {
		f[x][i] = f[f[x][i - 1]][i - 1];
	}
	for (auto y : e[x]) {
		if (y != fa) dfs1(y, x);
	}
}

int main() {
	n = read();
	for (int i = 1; i <= n; i ++) {
		int x = read();
		if (x) e[x].pb(i), e[i].pb(x);
		else root.pb(i);
	}

	for (auto x : root) dfs1(x, 0);

	m = read();
	for (int i = 1; i <= m; i ++) {
		int x = read(), y = read(), fa = x;
		for (int i = 0, j = y; i <= 20; i ++, j >>= 1) {
			if (j & 1) fa = f[fa][i];
		}
		q[fa].pb(mp(i, y));
	}

	for (auto x : root) dfs2(x, 0);

	for (int i = 1; i <= m; i ++) {
		printf("%d ", res[i]);
	} puts("");
	return 0;
}

线段树合并学习笔记
https://ybwa.github.io/p/647e170e/
作者
yb
发布于
2021年11月8日
许可协议