```mermaid graph LR A[Mermaid Keeper] --> B[保持激活] ```

前言

前面的 DP 都是处理线性序列或网格上的问题。这一篇介绍两种特殊的状态设计:

树形 DP——状态定义在树的节点上,按子树到根的顺序递推。04-08 树上问题 里讲的换根 DP 其实就是树形 DP 的一种。这篇会讲更基础的模型:最大独立集、删边方案数。

数位 DP——统计 11NN 中满足某种条件的数有多少个。NN 可以高达 101810^{18},但只要按位分解,状态数只有 O(logN×状态维度)O(\log N \times 状态维度)

问题的本质

树形 DP:为什么是”从叶到根”?

树有一个天然的计算顺序:先算完所有子树,再算当前节点。因为子树之间互不干扰(无后效性),每个子树可以独立求解,然后把结果汇总到父节点。

状态通常是 dp[v][某个维度]dp[v][\text{某个维度}] = 以 vv 为根的子树中,满足某种条件的最优值或方案数。转移时遍历 vv 的所有子节点,把子节点的 dpdp 值合并到 vv

数位 DP:为什么可以”逐位处理”?

一个数 NN 可以写成十进制序列 dkdk1d1d0d_k d_{k-1} \ldots d_1 d_0。数位 DP 的思路是:从高位到低位,逐位决定填什么数字。关键状态是”当前是否已经严格小于 NN“——如果已经小于,后面的位可以随便填;如果还没小于,就只能填 di\le d_i 的数字。

理论 + 代码

树形 DP:最大独立集

给定一棵树,选若干节点,要求相邻节点不能同时被选。求最多选几个节点?

状态dp[v][0]dp[v][0] = 不选节点 vv 时,子树 vv 中的最大独立集大小;dp[v][1]dp[v][1] = 选节点 vv 时的最大独立集大小。

转移

  • dp[v][0]=cmax(dp[c][0],dp[c][1])dp[v][0] = \sum_{c} \max(dp[c][0], dp[c][1])(不选 vv,子节点可选可不选)
  • dp[v][1]=1+cdp[c][0]dp[v][1] = 1 + \sum_{c} dp[c][0](选了 vv,子节点不能选)
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

vector<int> adj[100009];
int dp[100009][2]; // dp[v][0]=不选v, dp[v][1]=选v

void dfs(int v, int p) {
    dp[v][0] = 0;
    dp[v][1] = 1;                                   // ① 选 v 自身
    for (int c : adj[v]) {
        if (c == p) continue;
        dfs(c, v);
        dp[v][0] += max(dp[c][0], dp[c][1]);         // ② 子节点可选可不选
        dp[v][1] += dp[c][0];                         // ③ 选了 v,子节点不能选
    }
}

int main() {
    int N;
    scanf("%d", &N);
    for (int i = 0; i < N - 1; i++) {
        int a, b; scanf("%d%d", &a, &b);
        adj[a].push_back(b); adj[b].push_back(a);
    }
    dfs(1, 0);
    printf("%d\n", max(dp[1][0], dp[1][1]));
}

逐行解析

  • ① 选了节点 vv,独立集大小至少 1(vv 自身)。
  • ② 不选 vv 时,每个子节点 cc 可以选也可以不选,取较大值累加。
  • ③ 选了 vv 时,子节点 cc 一定不能选,所以只能加 dp[c][0]dp[c][0]

数位 DP:统计满足条件的数

统计 11NN 中,满足某种条件的数有多少个?

通用框架

// digit DP 框架
string S = to_string(N); // N 的十进制表示
int len = S.size();
long long dp[len+1][2]; // dp[pos][tight]

// dp[pos][tight] = 处理到第 pos 位时的方案数
// tight = 1 表示前面每一位都和 N 相同,当前位受限
// tight = 0 表示已经比 N 小了,当前位可以填 0~9

状态dp[pos][tight]dp[pos][tight] = 处理到第 pospos 位时的合法方案数。tight=1tight=1 表示”还没小于 NN”,tight=0tight=0 表示”已经小于 NN”。

转移:枚举当前位置填的数字 dd。如果 tight=1tight=1,则 dS[pos]d \le S[pos];如果 tight=0tight=0,则 dd 可以是 090 \sim 9。新的 tighttight 取决于是否还”紧贴” NN

例题

例题 1:T90 026 — Independent Set on a Tree(★4)

题目NN 个节点的树(NN 为偶数)。选出 N/2N/2 个节点,使得没有任何两个相邻节点同时被选。输出任意一个合法方案。

数据范围2N1052 \le N \le 10^5

—— AtCoder Typical 90 026

分析:树一定是二分图——可以把所有节点分成两组(按深度奇偶),同一组内没有相邻节点。所以答案就是深度为偶数的 N/2N/2 个节点,或深度为奇数的 N/2N/2 个节点(取较多的一组)。

不需要真的做 DP,二分图染色即可。

代码

#include <cstdio>
#include <vector>
using namespace std;

vector<int> adj[100009];
int color[100009]; // 0 或 1

void dfs(int v, int p, int c) {
    color[v] = c;
    for (int u : adj[v])
        if (u != p) dfs(u, v, 1 - c); // ① 交替染色
}

int main() {
    int N;
    scanf("%d", &N);
    for (int i = 0; i < N - 1; i++) {
        int a, b; scanf("%d%d", &a, &b);
        adj[a].push_back(b); adj[b].push_back(a);
    }
    dfs(1, 0, 0); // ② 从根开始染色
    vector<int> g0, g1;
    for (int i = 1; i <= N; i++) {
        if (color[i] == 0) g0.push_back(i);
        else g1.push_back(i);
    }
    // ③ 取较大的一组,输出前 N/2 个
    if (g0.size() >= g1.size()) {
        for (int i = 0; i < N / 2; i++) {
            if (i) printf(" ");
            printf("%d", g0[i]);
        }
    } else {
        for (int i = 0; i < N / 2; i++) {
            if (i) printf(" ");
            printf("%d", g1[i]);
        }
    }
    printf("\n");
}

逐行解析

  • ① DFS 交替染色:父节点染 0,子节点染 1,孙节点染 0……
  • ② 从节点 1 开始,染 0。
  • ③ 因为树是二分图,两组节点各自内部没有边相连。NN 是偶数,两组各至少有 N/2N/2 个。

验证N=4N=4,边 (1,2),(2,3),(2,4)(1,2), (2,3), (2,4)。DFS 染色:1→0, 2→1, 3→0, 4→0。组 0 = {1,3,4}(3 个),组 1 = {2}(1 个)。输出组 0 的前 2 个:3 4。样例答案也是 3 4。✓


例题 2:T90 073 — We Need Both a and b(★5,树上删边方案数)

题目NN 个节点的树,每个节点标 ‘a’ 或 ‘b’。删掉 0 条或多条边后,要求每个连通分量都同时包含 ‘a’ 和 ‘b’。求方案数。

数据范围2N1052 \le N \le 10^5

—— AtCoder Typical 90 073

分析:树形 DP + 容斥。

dp[v][0]dp[v][0] = 子树 vv 中,vv 所在连通分量只有 a 的方案数;dp[v][1]dp[v][1] = 只有 b 的方案数;dp[v][2]dp[v][2] = 同时有 a 和 b 的方案数。

对节点 vv 和子节点 cc,考虑边 (v,c)(v, c)

  • 保留边:cc 所在连通分量和 vv 合并
  • 删去边:cc 所在连通分量必须已经满足条件(即 dp[c][2]dp[c][2]

转移(新 dp[v][]dp'[v][*] 从旧 dp[v][]dp[v][*]dp[c][]dp[c][*] 推导):

合并时,vv 的连通分量和 cc 的连通分量组合:

  • (只有a) + (只有a) → (只有a)
  • (只有b) + (只有b) → (只有b)
  • (只有a) + (只有b) → (有ab)
  • 任何 + (有ab) → (有ab)

代码

#include <cstdio>
#include <vector>
using namespace std;

const int MOD = 1000000007;
vector<int> adj[100009];
char c[100009];
long long dp[100009][3]; // 0=only_a, 1=only_b, 2=both

void dfs(int v, int p) {
    if (c[v] == 'a') { dp[v][0] = 1; dp[v][1] = 0; }
    else { dp[v][0] = 0; dp[v][1] = 1; } // ① 初始
    dp[v][2] = 0;

    for (int u : adj[v]) {
        if (u == p) continue;
        dfs(u, v);
        long long n0 = dp[v][0] * (dp[u][0] + dp[u][2]) % MOD; // ② v只有a
        long long n1 = dp[v][1] * (dp[u][1] + dp[u][2]) % MOD; // ③ v只有b
        long long n2 = (dp[v][2] * (dp[u][0] + dp[u][1] + dp[u][2]) % MOD // ④ v有ab
                      + dp[v][0] * dp[u][1] % MOD   // ⑤ v的a + u的b
                      + dp[v][1] * dp[u][0] % MOD   // ⑥ v的b + u的a
                      ) % MOD;
        dp[v][0] = n0;
        dp[v][1] = n1;
        dp[v][2] = n2;
    }
}

int main() {
    int N;
    scanf("%d", &N);
    for (int i = 1; i <= N; i++) {
        char ch; scanf(" %c", &ch); c[i] = ch;
    }
    for (int i = 0; i < N - 1; i++) {
        int a, b; scanf("%d%d", &a, &b);
        adj[a].push_back(b); adj[b].push_back(a);
    }
    dfs(1, 0);
    printf("%lld\n", dp[1][2]); // ⑦ 根节点的"有ab"方案数
}

逐行解析

  • ① 叶节点初始化:如果 vv 是 ‘a’,dp[v][0]=1;如果是 ‘b’,dp[v][1]=1dp[v][2]=0(单个节点不可能同时有 ab)。
  • ② 保留边 + uu 只有 a:vv 的连通分量仍只有 a。uu 的”有ab”情况也可以保留(合并后 vv 仍只有 a,因为 uu 那边没有 b)。
  • ⑤⑥ 关键:vv 的连通分量只有 a,uu 的只有 b,合并后变成”有 ab”。
  • ⑦ 答案是根节点”有 ab”的方案数。

例题 3:T90 082 — Counting Numbers(★3,按位数分段求和)

题目x=LRx×(十进制位数 of x)\sum_{x=L}^{R} x \times (\text{十进制位数 of } x)L,R1018L, R \le 10^{18}。答案对 109+710^9+7 取模。

—— AtCoder Typical 90 082

分析:不需要严格的数位 DP。按位数分段:1 位数(1-9),2 位数(10-99),3 位数(100-999),……对每一段,x=lrx=(l+r)(rl+1)/2\sum_{x=l}^{r} x = (l+r)(r-l+1)/2,乘以位数即可。

代码

#include <cstdio>
using namespace std;

const long long MOD = 1000000007;

long long solve(long long N) {
    if (N <= 0) return 0;
    long long ans = 0;
    long long p10 = 1; // 10^k(当前位数的起点)
    for (int d = 1; p10 <= N; d++) { // d = 位数
        long long lo = p10;              // ① 当前段最小值
        long long hi = min(N, p10 * 10 - 1); // ② 当前段最大值
        long long cnt = (hi - lo + 1) % MOD; // ③ 这个段有几个数
        long long sum = ((lo + hi) % MOD) * cnt % MOD * ((MOD + 1) / 2) % MOD; // ④ 等差数列求和
        ans = (ans + sum % MOD * d % MOD) % MOD; // ⑤ 乘以位数
        p10 *= 10;
    }
    return ans;
}

int main() {
    long long L, R;
    scanf("%lld%lld", &L, &R);
    long long ans = (solve(R) - solve(L - 1) % MOD + MOD) % MOD; // ⑥ 前缀和相减
    printf("%lld\n", ans);
}

逐行解析

  • ①② 当前位数的段从 10d110^{d-1}min(N,10d1)\min(N, 10^d - 1)。比如 3 位数从 100 到 min(N,999)\min(N, 999)
  • ④ 等差数列求和 (l+r)(rl+1)/2(l+r)(r-l+1)/2。除以 2 用乘以逆元 (MOD+1)/2=5×108+4(MOD+1)/2 = 5 \times 10^8 + 4
  • ⑤ 每个数贡献的字符数 = 数字本身 × 位数。这里 sum 是数字之和,乘以位数 dd
  • LR=1R1L1\sum_{L}^{R} = \sum_1^R - \sum_1^{L-1}

验证L=3,R=5L=3, R=5

solve(5):1 位数段 lo=1, hi=5, sum=(1+5)*5/2=15, 乘以 1 = 15。 solve(2):sum=(1+2)*2/2=3, 乘以 1 = 3。 答案 = 15 - 3 = 12。✓(3+3+3+4+4+4+4+5+5+5+5+5 = 12 个字符。)


例题 4:T90 005 — Restricted Digits(★7,数位 DP 模板)

题目:统计 11NN 中,所有数位都来自集合 S={d1,d2,,dK}S = \{d_1, d_2, \ldots, d_K\} 的数有多少个。答案对 109+710^9+7 取模。

数据范围1N10181 \le N \le 10^{18}1K101 \le K \le 10

—— AtCoder Typical 90 005

分析:标准数位 DP。从高位到低位逐位决策,状态是”当前处理到第几位”和”是否已经严格小于 NN”。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MOD = 1000000007;
long long dp[20][2]; // dp[pos][tight]
int digit[20], D[11];
int len, K;                                          // K = 允许的数字个数

long long dfs(int pos, bool tight) {
    if (pos == len) return 1;                           // ① 填完了,1 个合法数
    if (dp[pos][tight] >= 0) return dp[pos][tight];    // ② 记忆化

    long long res = 0;
    int limit = tight ? digit[pos] : 9;                // ③ 当前位上限
    for (int i = 0; i < K; i++) {
        if (D[i] > limit) continue;                     // ④ 超出上限,跳过
        bool ntight = tight && (D[i] == limit);         // ⑤ 新的 tight 状态
        res = (res + dfs(pos + 1, ntight)) % MOD;
    }
    return dp[pos][tight] = res;
}

int main() {
    long long N;                                        // ⑥ 上界 N
    scanf("%lld%d", &N, &K);
    for (int i = 0; i < K; i++) scanf("%d", &D[i]);    // ⑦ 允许的数字集合
    sort(D, D + K);                                     // ⑧ 排序便于理解

    // ⑨ 把 N 转成数字数组
    len = 0;
    long long tmp = N;
    while (tmp > 0) { digit[len++] = tmp % 10; tmp /= 10; }
    reverse(digit, digit + len);                        // 高位在前

    memset(dp, -1, sizeof(dp));                         // ⑩ 初始化记忆化数组
    printf("%lld\n", dfs(0, true));                      // ⑪ 从第 0 位开始,初始 tight=true
}

逐行解析

  • ① 递归边界:所有位都填完了,得到 1 个合法数。
  • ② 记忆化——同一个 (pos,tight)(pos, tight) 只算一次。
  • ③ 如果 tight=1tight=1,当前位最大填 digit[pos]digit[pos]NN 的这一位);否则可以填 0-9。
  • ④ 只能填集合 SS 中的数字。如果 SS 中的某个数超过上限,跳过。
  • ⑤ 新的 tighttight:只有之前一直紧贴 NN 且当前位也刚好等于上限时,才继续紧贴。
  • ⑨ 把 NN 的十进制表示存入 digit[] 数组,高位在前。
  • -1 表示未计算。每次 dfs 会自动缓存结果。
  • ⑪ 从最高位开始,初始 tight=truetight = \text{true}(还没小于 NN)。

注意:这个递归模板可以适配各种数位 DP 题目——只需要修改 dfs 中的枚举逻辑和终止条件。

参考文献

实战练习 — 競プロ典型 90 問


系列索引

第零章 基础工具

第一章 搜索技术

第二章 数学基础

第三章 数据结构

第四章 图论

第五章 动态规划

第六章 贪心

第七章 字符串

第八章 进阶