LogoCSP Wiki By Yundou
3 DS

线段树

引入

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。

线段树可以在 O(logN)O(\log N) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树的基本结构与建树

过程

线段树将每个长度不为 11 的区间划分成左右两个区间递归求解,把整个线段划分为一个树形结构,通过合并左右两区间信息来求得该区间的信息。这种数据结构可以方便的进行大部分的区间操作。

有个大小为 55 的数组 a={10,11,12,13,14}a=\{10,11,12,13,14\},要将其转化为线段树,有以下做法:设线段树的根节点编号为 11,用数组 dd 来保存我们的线段树,did_i 用来保存线段树上编号为 ii 的节点的值(这里每个节点所维护的值就是这个节点所表示的区间总和)。

我们先给出这棵线段树的形态,如图所示:

图片描述

图中每个节点中用红色字体标明的区间,表示该节点管辖的 aa 数组上的位置区间。如 d1d_1 所管辖的区间就是 [1,5][1,5]a1,a2,,a5a_1,a_2, \cdots ,a_5),即 d1d_1 所保存的值是 a1+a2++a5a_1+a_2+ \cdots +a_5d1=60d_1=60 表示的是 a1+a2++a5=60a_1+a_2+ \cdots +a_5=60

通过观察不难发现,did_i 的左儿子节点就是 d2×id_{2\times i}did_i 的右儿子节点就是 d2×i+1d_{2\times i+1}。如果 did_i 表示的是区间 [s,t][s,t](即 di=as+as+1++atd_i=a_s+a_{s+1}+ \cdots +a_t)的话,那么 did_i 的左儿子节点表示的是区间 [s,s+t2][ s, \frac{s+t}{2} ]did_i 的右儿子表示的是区间 [s+t2+1,t][ \frac{s+t}{2} +1,t ]

在实现时,我们考虑递归建树。设当前的根节点为 pp,如果根节点管辖的区间长度已经是 11,则可以直接根据 aa 数组上相应位置的值初始化该节点。否则我们将该区间从中点处分割为两个子区间,分别进入左右子节点递归建树,最后合并两个子节点的信息。

实现

此处给出代码实现,可参考注释理解:

void build(int s, int t, int p) {
  // 对 [s,t] 区间建立线段树,当前根的编号为 p
  if (s == t) {
    d[p] = a[s];
    return;
  }
  int m = s + ((t - s) >> 1);
  // 移位运算符的优先级小于加减法,所以加上括号
  // 如果写成 (s + t) >> 1 可能会超出 int 范围
  build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
  // 递归对左右区间建树
  d[p] = d[p * 2] + d[(p * 2) + 1];
}

关于线段树的空间:如果采用堆式存储(2p2ppp 的左儿子,2p+12p+1pp 的右儿子),若有 nn 个叶子结点,则 d 数组的范围最大为 2logn+12^{\left\lceil\log{n}\right\rceil+1}

分析:容易知道线段树的深度是 logn\left\lceil\log{n}\right\rceil 的,则在堆式储存情况下叶子节点(包括无用的叶子节点)数量为 2logn2^{\left\lceil\log{n}\right\rceil} 个,又由于其为一棵完全二叉树,则其总节点个数 2logn+112^{\left\lceil\log{n}\right\rceil+1}-1。当然如果你懒得计算的话可以直接把数组长度设为 4n4n,因为 2logn+11n\frac{2^{\left\lceil\log{n}\right\rceil+1}-1}{n} 的最大值在 n=2x+1(xN+)n=2^{x}+1(x\in N_{+}) 时取到,此时节点数为 2logn+11=2x+21=4n52^{\left\lceil\log{n}\right\rceil+1}-1=2^{x+2}-1=4n-5

而堆式存储存在无用的叶子节点,可以考虑使用内存池管理线段树节点,每当需要新建节点时从池中获取。自底向上考虑,必有每两个底层节点合并为一个上层节点,因此可以类似哈夫曼树地证明,如果有 nn 个叶子节点,这样的线段树总共有 2n12n-1 个节点。其空间效率优于堆式存储,并且是可能的最优情况。

线段树的区间查询

过程

区间查询,比如求区间 [l,r][l,r] 的总和(即 al+al+1++ara_l+a_{l+1}+ \cdots +a_r)、求区间最大值/最小值等操作。

图片描述

仍然以最开始的图为例,如果要查询区间 [1,5][1,5] 的和,那直接获取 d1d_1 的值(6060)即可。

如果要查询的区间为 [3,5][3,5],此时就不能直接获取区间的值,但是 [3,5][3,5] 可以拆成 [3,3][3,3][4,5][4,5],可以通过合并这两个区间的答案来求得这个区间的答案。

一般地,如果要查询的区间是 [l,r][l,r],则可以将其拆成最多为 O(logn)O(\log n)极大 的区间,合并这些区间即可求出 [l,r][l,r] 的答案。

实现

此处给出代码实现,可参考注释理解:

int getsum(int l, int r, int s, int t, int p) {
  // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
  if (l <= s && t <= r)
    return d[p];  // 当前区间为询问区间的子集时直接返回当前区间的和
  int m = s + ((t - s) >> 1), sum = 0;
  if (l <= m) sum += getsum(l, r, s, m, p * 2);
  // 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  // 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子
  return sum;
}

线段树的区间修改与懒惰标记

过程

如果要求修改区间 [l,r][l,r],把所有包含在区间 [l,r][l,r] 中的节点都遍历一次、修改一次,时间复杂度无法承受。我们这里要引入一个叫做 「懒惰标记」 的东西。

懒惰标记,简单来说,就是通过延迟对节点信息的更改,从而减少可能不必要的操作次数。每次执行修改时,我们通过打标记的方法表明该节点对应的区间在某一次操作中被更改,但不更新该节点的子节点的信息。实质性的修改则在下一次访问带有标记的节点时才进行。

仍然以最开始的图为例,我们将执行若干次给区间内的数加上一个值的操作。我们现在给每个节点增加一个 tit_i,表示该节点带的标记值。

最开始时的情况是这样的(为了节省空间,这里不再展示每个节点管辖的区间):

图片描述

现在我们准备给 [3,5][3,5] 上的每个数都加上 55。根据前面区间查询的经验,我们很快找到了两个极大区间 [3,3][3,3][4,5][4,5](分别对应线段树上的 33 号点和 55 号点)。

我们直接在这两个节点上进行修改,并给它们打上标记:

图片描述

我们发现,33 号节点的信息虽然被修改了(因为该区间管辖两个数,所以 d3d_3 加上的数是 5×2=105 \times 2=10),但它的两个子节点却还没更新,仍然保留着修改之前的信息。不过不用担心,虽然修改目前还没进行,但当我们要查询这两个子节点的信息时,我们会利用标记修改这两个子节点的信息,使查询的结果依旧准确。

接下来我们查询一下 [4,4][4,4] 区间上各数字的和。

我们通过递归找到 [4,5][4,5] 区间,发现该区间并非我们的目标区间,且该区间上还存在标记。这时候就到标记下放的时间了。我们将该区间的两个子区间的信息更新,并清除该区间上的标记。

图片描述

现在 6677 两个节点的值变成了最新的值,查询的结果也是准确的。

实现

接下来给出在存在标记的情况下,区间修改和查询操作的参考实现。

区间修改(区间加上某个值):

// [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间, p
// 为当前节点的编号
void update(int l, int r, int c, int s, int t, int p) {
  // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改
  if (l <= s && t <= r) {
    d[p] += (t - s + 1) * c, b[p] += c;
    return;
  }
  int m = s + ((t - s) >> 1);
  if (b[p] && s != t) {
    // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
    d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
    b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
    b[p] = 0;                                // 清空当前节点的标记
  }
  if (l <= m) update(l, r, c, s, m, p * 2);
  if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
  d[p] = d[p * 2] + d[p * 2 + 1];
}

区间查询(区间求和):

int getsum(int l, int r, int s, int t, int p) {
  // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
  if (l <= s && t <= r) return d[p];
  // 当前区间为询问区间的子集时直接返回当前区间的和
  int m = s + ((t - s) >> 1);
  if (b[p]) {
    // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
    d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
    b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
    b[p] = 0;                                // 清空当前节点的标记
  }
  int sum = 0;
  if (l <= m) sum = getsum(l, r, s, m, p * 2);
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  return sum;
}

如果你是要实现区间修改为某一个值而不是加上某一个值的话,代码如下:

void update(int l, int r, int c, int s, int t, int p) {
  if (l <= s && t <= r) {
    d[p] = (t - s + 1) * c, b[p] = c, v[p] = 1;
    return;
  }
  int m = s + ((t - s) >> 1);
  // 额外数组储存是否修改值
  if (v[p]) {
    d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
    b[p * 2] = b[p * 2 + 1] = b[p];
    v[p * 2] = v[p * 2 + 1] = 1;
    v[p] = 0;
  }
  if (l <= m) update(l, r, c, s, m, p * 2);
  if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
  d[p] = d[p * 2] + d[p * 2 + 1];
}
 
int getsum(int l, int r, int s, int t, int p) {
  if (l <= s && t <= r) return d[p];
  int m = s + ((t - s) >> 1);
  if (v[p]) {
    d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
    b[p * 2] = b[p * 2 + 1] = b[p];
    v[p * 2] = v[p * 2 + 1] = 1;
    v[p] = 0;
  }
  int sum = 0;
  if (l <= m) sum = getsum(l, r, s, m, p * 2);
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  return sum;
}

例题:

On this page