```mermaid graph LR A[Mermaid Keeper] --> B[保持激活] ```

前言

上一篇的树状数组代码简洁,但功能有限——只能做前缀查询。如果你要查任意区间的最大值呢?如果数组初始不是全零呢?

线段树(Segment Tree)能做树状数组能做的一切,而且更多。它的核心理念是:把数组区间递归地一分为二,形成一棵二叉树。每个节点存储一个区间的信息(和、最大值、最小值等),查询时合并若干节点的结果,修改时更新沿途所有节点。

代码比 BIT 长一些(大约 50 行核心代码),但思想清晰,适用范围极广。它是竞赛中最常用的数据结构之一。

问题的本质

递归分治:把区间变成一棵树

给定数组 [3,1,4,1,5,9,2,6][3, 1, 4, 1, 5, 9, 2, 6],线段树长这样:

            [1,8] sum=31
           /            \
      [1,4] sum=9     [5,8] sum=22
      /      \          /      \
  [1,2]   [3,4]    [5,6]   [7,8]
  sum=4   sum=5    sum=14  sum=8
  / \     / \      / \     / \
 [1] [2] [3] [4] [5] [6] [7] [8]
  3   1   4   1   5   9   2   6
  • 每个叶子节点对应数组的一个元素
  • 每个内部节点对应一个区间,存储该区间的聚合值(和/最大值等)
  • 查询 [l,r][l, r]:从根出发,找到完全被覆盖的若干节点,合并结果
  • 修改位置 pp:更新叶子,然后一路向上更新祖先

为什么查询是 O(logn)O(\log n)

一条查询路径最多涉及 logn\log n 层,每层最多选 2 个节点(一个左、一个右),所以总共 O(logn)O(\log n) 个节点。

理论 + 代码

存储

用数组 tree[4*MAXN] 存储完全二叉树。对于节点 node 管辖区间 [lo, hi]

  • 左子节点:2*node,区间 [lo, mid]
  • 右子节点:2*node+1,区间 [mid+1, hi]

单点修改

void update(int node, int lo, int hi, int idx, long long val) {
    if (lo == hi) { tree[node] = val; return; } // ① 叶子节点,直接赋值
    int mid = (lo + hi) / 2;
    if (idx <= mid) update(2*node, lo, mid, idx, val);     // ② 在左子树
    else            update(2*node+1, mid+1, hi, idx, val); // ③ 在右子树
    tree[node] = tree[2*node] + tree[2*node+1]; // ④ 合并子节点的值
}

区间查询

long long query(int node, int lo, int hi, int ql, int qr) {
    if (qr < lo || hi < ql) return 0;           // ① 完全不交,返回零元
    if (ql <= lo && hi <= qr) return tree[node]; // ② 完全包含,直接返回
    int mid = (lo + hi) / 2;
    return query(2*node, lo, mid, ql, qr)       // ③ 递归查询左右子树
         + query(2*node+1, mid+1, hi, ql, qr);
}

Monoid(幺半群)抽象

线段树只要求操作满足:

  1. 结合律(ab)c=a(bc)(a \oplus b) \oplus c = a \oplus (b \oplus c)
  2. 单位元:存在 ee 使得 ae=aa \oplus e = a

加法的单位元是 0,max 的单位元是 -\infty。所以线段树可以维护和、最大值、最小值、GCD 等——只要满足结合律和有单位元。

模拟走一遍

数组 [0,0,16,0,13,0,0,0][0, 0, 16, 0, 13, 0, 0, 0](1-indexed),查询区间 [4, 7)[4, 6]

根节点 [1,8] sum=29
  → 左 [1,4] sum=16: [4,6] 不完全包含
    → 左 [1,2] sum=0: [4,6] 不交 → 返回 0
    → 右 [3,4] sum=16: [4,6] 不完全包含
      → 左 [3,3] sum=16: 不交 → 返回 0
      → 右 [4,4] sum=0: 完全包含 → 返回 0
    → 返回 0 + 0 = 0
  → 右 [5,8] sum=13: [4,6] 不完全包含
    → 左 [5,6] sum=13: 完全包含 → 返回 13
    → 右 [7,8] sum=0: 不交 → 返回 0
  → 返回 13 + 0 = 13
→ 总 = 0 + 13 = 13 ✓

例题

例题 1:TB A59 — RSQ (Range Sum Queries)

题目:长度为 NN 的数组,初始全 0。QQ 个查询:

  • 1 pos x:把 AposA_{pos} 更新为 xx
  • 2 l r:输出 Al+Al+1++Ar1A_l + A_{l+1} + \cdots + A_{r-1} 的和

数据范围1N,Q1051 \le N, Q \le 10^50x10000 \le x \le 1000

—— AtCoder Tessoku Book A59

分析:线段树区间求和模板题。注意查询是 [l,r)[l, r)[l,r1][l, r-1]

代码

#include <cstdio>
using namespace std;

const int MAXN = 400006;
long long tree[MAXN];

void update(int node, int lo, int hi, int idx, long long val) {
    if (lo == hi) { tree[node] = val; return; }
    int mid = (lo + hi) / 2;
    if (idx <= mid) update(2*node, lo, mid, idx, val);
    else            update(2*node+1, mid+1, hi, idx, val);
    tree[node] = tree[2*node] + tree[2*node+1];
}

long long query(int node, int lo, int hi, int ql, int qr) {
    if (qr < lo || hi < ql) return 0;
    if (ql <= lo && hi <= qr) return tree[node];
    int mid = (lo + hi) / 2;
    return query(2*node, lo, mid, ql, qr) + query(2*node+1, mid+1, hi, ql, qr);
}

int main() {
    int N, Q;
    scanf("%d%d", &N, &Q);
    // ① 初始全 0,不需要初始化
    while (Q--) {
        int t;
        scanf("%d", &t);
        if (t == 1) {
            int pos; long long x;
            scanf("%d%lld", &pos, &x);
            update(1, 1, N, pos, x);        // ② 单点更新
        } else {
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%lld\n", query(1, 1, N, l, r - 1)); // ③ 查询 [l, r-1]
        }
    }
    return 0;
}

逐行解析

  • ① 全局变量初始为 0,而题目初始全 0,无需额外初始化。
  • update(1, 1, N, pos, x):从根节点开始,修改位置 pos 的值为 x
  • ③ 查询 [l, r-1](题目说 2 l r 查的是 AlA_lAr1A_{r-1} 的和)。

例题 2:TB A58 — RMQ (Range Maximum Queries)

题目:长度为 NN 的数组,初始全 0。QQ 个查询:

  • 1 pos x:把 AposA_{pos} 更新为 xx
  • 2 l r:输出 Al,Al+1,,Ar1A_l, A_{l+1}, \ldots, A_{r-1} 的最大值

数据范围1N,Q1051 \le N, Q \le 10^50x1090 \le x \le 10^9

—— AtCoder Tessoku Book A58

分析:把线段树的操作从”求和”改成”取最大”即可。单位元从 0 变成 -\infty

代码

#include <cstdio>
#include <algorithm>
using namespace std;

const int MAXN = 400006;
int tree[MAXN];

void update(int node, int lo, int hi, int idx, int val) {
    if (lo == hi) { tree[node] = val; return; }
    int mid = (lo + hi) / 2;
    if (idx <= mid) update(2*node, lo, mid, idx, val);
    else            update(2*node+1, mid+1, hi, idx, val);
    tree[node] = max(tree[2*node], tree[2*node+1]); // ① 取 max 而不是求和
}

int query(int node, int lo, int hi, int ql, int qr) {
    if (qr < lo || hi < ql) return 0;              // ② 全 0 初始,0 是安全的零元
    if (ql <= lo && hi <= qr) return tree[node];
    int mid = (lo + hi) / 2;
    return max(query(2*node, lo, mid, ql, qr),      // ③ max 合并
               query(2*node+1, mid+1, hi, ql, qr));
}

int main() {
    int N, Q;
    scanf("%d%d", &N, &Q);
    while (Q--) {
        int t;
        scanf("%d", &t);
        if (t == 1) {
            int pos, x;
            scanf("%d%d", &pos, &x);
            update(1, 1, N, pos, x);
        } else {
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%d\n", query(1, 1, N, l, r - 1));
        }
    }
    return 0;
}

逐行解析

  • ① 把线段树的合并操作从 +(求和)改成 max(取最大值)。
  • ② 不交区间的零元是 0(因为所有值 0\ge 0,初始全 0)。
  • ③ 合并时用 max 而不是 +。其他结构完全相同。

例题 3:ACL Practice J — Segment Tree

题目:长度 NN 的数组 AAQQ 个查询,支持 3 种操作:

  • 1 X VAX=VA_X = V
  • 2 L R:输出 max(AL,,AR)\max(A_L, \ldots, A_R)
  • 3 X V:找最小 jXj \ge X 使得 AjVA_j \ge V,不存在则输出 N+1N+1

数据范围1N,Q2×1051 \le N, Q \le 2 \times 10^5

—— AtCoder ACL Practice J

分析:查询 1 和 2 是标准线段树操作。查询 3 是线段树的二分搜索——在 [X,N][X, N] 中找到第一个 V\ge V 的位置。

代码

#include <cstdio>
#include <algorithm>
using namespace std;

const int MAXN = 800006;
int tree[MAXN], N;

void update(int node, int lo, int hi, int idx, int val) {
    if (lo == hi) { tree[node] = val; return; }
    int mid = (lo + hi) / 2;
    if (idx <= mid) update(2*node, lo, mid, idx, val);
    else            update(2*node+1, mid+1, hi, idx, val);
    tree[node] = max(tree[2*node], tree[2*node+1]);
}

int query(int node, int lo, int hi, int ql, int qr) {
    if (qr < lo || hi < ql) return 0;
    if (ql <= lo && hi <= qr) return tree[node];
    int mid = (lo + hi) / 2;
    return max(query(2*node, lo, mid, ql, qr),
               query(2*node+1, mid+1, hi, ql, qr));
}

// ④ 线段树上二分:找 [lo,hi] 中第一个 >= val 的位置
int findFirst(int node, int lo, int hi, int start, int val) {
    if (hi < start || tree[node] < val) return N + 1; // ⑤ 不在范围或最大值不够
    if (lo == hi) return lo;                           // ⑥ 找到了
    int mid = (lo + hi) / 2;
    int left = findFirst(2*node, lo, mid, start, val); // ⑦ 先查左子树
    if (left != N + 1) return left;
    return findFirst(2*node+1, mid+1, hi, start, val); // ⑧ 左边没有,查右边
}

int main() {
    int Q;
    scanf("%d%d", &N, &Q);
    for (int i = 1; i <= N; i++) {
        int a;
        scanf("%d", &a);
        update(1, 1, N, i, a);
    }
    while (Q--) {
        int t;
        scanf("%d", &t);
        if (t == 1) {
            int X, V;
            scanf("%d%d", &X, &V);
            update(1, 1, N, X, V);
        } else if (t == 2) {
            int L, R;
            scanf("%d%d", &L, &R);
            printf("%d\n", query(1, 1, N, L, R));
        } else {
            int X, V;
            scanf("%d%d", &X, &V);
            printf("%d\n", findFirst(1, 1, N, X, V));
        }
    }
    return 0;
}

逐行解析

  • ④⑤⑥⑦⑧ 线段树上二分是线段树的高级技巧。如果当前区间的最大值 <val< val,直接排除。否则递归先查左子树(保证找到最小的 jj)。

参考文献

教材讲解 — 競技プログラミングの鉄則 第 8 章

系统练习 — 競技プログラミングの鉄則

实战练习 — 競プロ典型 90 問

模板练习 — ACL Practice Contest


系列索引

第零章 基础工具

第一章 搜索技术

第二章 数学基础

第三章 数据结构

第四章 图论

第五章 动态规划

第六章 贪心

第七章 字符串

第八章 进阶