```mermaid graph LR A[Mermaid Keeper] --> B[保持激活] ```

前言

两个数组 aabb 的卷积定义为 ck=j=0kajbkjc_k = \sum_{j=0}^{k} a_j \cdot b_{k-j}。直接计算是 O(N2)O(N^2) 的。但卷积等价于多项式乘法:C(x)=A(x)B(x)C(x) = A(x) \cdot B(x),而 FFT(快速傅里叶变换)能把多项式乘法加速到 O(NlogN)O(N \log N)

这一篇聚焦于竞赛中怎么用,而不是推导数学公式。核心问题只有一个:给定两个数组,O(NlogN)O(N \log N) 算出它们的卷积。

问题的本质

为什么卷积可以加速?

朴素的多项式乘法是逐系数计算——ckc_k 需要遍历 jj,这是 O(N2)O(N^2)。但换个角度看:如果我们知道多项式 C(x)C(x) 在足够多点的取值(即点值表示),就能通过插值还原系数。而点值表示下的乘法只需逐点相乘,是 O(N)O(N)

FFT 做的就是这件事:系数表示 →(FFT)→ 点值表示 →(逐点相乘)→ 点值表示 →(IFFT)→ 系数表示。正向变换和逆向变换都是 O(NlogN)O(N \log N)

NTT vs FFT

FFT 用浮点数(complex<double>),精度有限。NTT(Number Theoretic Transform,数论变换)在模素数意义下做等价的变换,全程整数运算,没有精度问题。

竞赛中最常用的 NTT 模数是 998244353998244353=119×223+1= 119 \times 2^{23} + 1,原根 g=3g = 3)。109+710^9+7 不是 NTT 友好的模数(它不满足 2k(p1)2^k | (p-1) 的要求),所以涉及卷积的题几乎都用 998244353998244353

理论 + 代码

使用 ACL 的 atcoder::convolution

在 AtCoder 上,最简单的做法是直接用 ACL 的 convolution 函数:

#include <cstdio>
#include <vector>
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using mint = atcoder::modint998244353;

int main() {
    int N, M;
    scanf("%d%d", &N, &M);
    vector<mint> a(N), b(M);
    for (int i = 0; i < N; i++) { int x; scanf("%d", &x); a[i] = x; } // ① 读入 a
    for (int i = 0; i < M; i++) { int x; scanf("%d", &x); b[i] = x; } // ② 读入 b

    auto c = atcoder::convolution(a, b);                               // ③ O(N log N) 卷积
    for (int i = 0; i < (int)c.size(); i++) {
        if (i) printf(" ");
        printf("%d", c[i].val());
    }
    printf("\n");
}

逐行解析

  • ①② 用 modint998244353 包装输入,自动取模。
  • convolution 内部用 NTT 实现,O((N+M)log(N+M))O((N+M) \log(N+M))

手写 NTT 模板

如果不使用 ACL(比如其他 OJ),需要手写 NTT。以下是 998244353998244353 下的 NTT 模板:

#include <cstdio>
#include <vector>
using namespace std;

const long long MOD = 998244353;
const long long G = 3; // 998244353 的原根

long long power(long long a, long long b, long long mod) {
    long long res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

void ntt(vector<long long>& a, bool inverse) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);                   // ① 蝶形置换
    }
    for (int len = 2; len <= n; len <<= 1) {
        long long w = inverse ?
            power(power(G, (MOD-1)/len, MOD), MOD-2, MOD) : // ② IFFT 用 w^(-1)
            power(G, (MOD-1)/len, MOD);                      // ③ FFT 用 w
        for (int i = 0; i < n; i += len) {
            long long wn = 1;
            for (int j = 0; j < len/2; j++) {
                long long u = a[i+j], v = a[i+j+len/2] * wn % MOD;
                a[i+j] = (u + v) % MOD;                 // ④ 蝶形运算
                a[i+j+len/2] = (u - v + MOD) % MOD;
                wn = wn * w % MOD;
            }
        }
    }
    if (inverse) {
        long long n_inv = power(n, MOD-2, MOD);
        for (auto& x : a) x = x * n_inv % MOD;         // ⑤ IFFT 最后除以 n
    }
}

vector<long long> convolution(vector<long long> a, vector<long long> b) {
    int result_size = a.size() + b.size() - 1;
    int n = 1;
    while (n < result_size) n <<= 1;                    // ⑥ 补到 2 的幂
    a.resize(n); b.resize(n);
    ntt(a, false); ntt(b, false);                       // ⑦ 正变换
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD; // ⑧ 逐点相乘
    ntt(a, true);                                       // ⑨ 逆变换
    a.resize(result_size);
    return a;
}

逐行解析

  • ① 位逆序置换(Bit-reversal permutation),确保蝶形运算正确。
  • ②③ FFT 用 w=g(p1)/lenw = g^{(p-1)/len},IFFT 用 w1w^{-1}
  • ④ 蝶形运算:u=u+vwju' = u + v \cdot w^jv=uvwjv' = u - v \cdot w^j
  • ⑤ IFFT 最后每个元素除以 nn(乘以 n1n^{-1})。
  • ⑥ 补到 2 的幂——NTT 要求长度是 2 的幂。
  • ⑦⑧⑨ 正变换 → 逐点相乘 → 逆变换,三步完成卷积。

例题

例题 1:ACL F — Convolution(NTT 模板)

题目:给定 a0,,aN1a_0, \ldots, a_{N-1}b0,,bM1b_0, \ldots, b_{M-1},求 ck=j=0kajbkjmod998244353c_k = \sum_{j=0}^{k} a_j b_{k-j} \bmod 998244353

数据范围1N,M5242881 \le N, M \le 524288

—— AtCoder Library Practice Contest F

分析:卷积模板题。用 ACL 的 convolution 或手写 NTT。

输入样例

4 5
1 2 3 4
5 6 7 8 9

输出5 16 34 60 70 70 59 36

验证c0=1×5=5c_0 = 1 \times 5 = 5c1=1×6+2×5=16c_1 = 1 \times 6 + 2 \times 5 = 16c2=1×7+2×6+3×5=34c_2 = 1 \times 7 + 2 \times 6 + 3 \times 5 = 34。✓

代码:(使用 ACL 版本,同上面模板,此处省略重复代码。)


例题 2:M&A 102 — Tricolor Pyramid(卷积的巧妙应用)

题目NN 个蓝/白/红方块排成一行,相邻两个按规则生成上层方块,最终顶部是什么颜色?

规则:同色→同色,异色→第三色。

数据范围2N4000002 \le N \le 400000

—— AtCoder M&A 102

分析:把 B/W/R 映射为 0/1/2。规则变为:f(x,y)=(xy)mod3f(x,y) = (-x-y) \bmod 3。容易验证:同色 f(x,x)=(2x)mod3=xf(x,x) = (-2x) \bmod 3 = x(因为 3x03x \equiv 0),异色 f(x,y)=(xy)mod3f(x,y) = (-x-y) \bmod 3 正好是第三种颜色。

金字塔可以逐层展开。NN 个底层的方块经过 N1N-1 层后,顶部颜色的公式为:

top=(1)N1i=0N1(N1i)ci(mod3)\text{top} = (-1)^{N-1} \sum_{i=0}^{N-1} \binom{N-1}{i} c_i \pmod{3}

其中 cic_i 是第 ii 个方块的颜色值。(N1i)\binom{N-1}{i}Lucas 定理在 mod 3 下计算。当 N1N-1 为奇数时,(1)N1=1(-1)^{N-1} = -1,需要最终取反。

代码

#include <cstdio>
#include <cstring>
using namespace std;

const int MOD = 3;

long long power(long long a, long long b, long long mod) {
    long long res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Lucas 定理求 C(n, k) mod 3
long long C3(long long n, long long k) {
    if (k < 0 || k > n) return 0;
    long long res = 1;
    while (n > 0 || k > 0) {
        int a = n % 3, b = k % 3;                      // ① 3 进制分解
        if (b > a) return 0;                            // ② C(a,b) = 0 当 b > a
        // C(a, b) for a, b in {0,1,2}
        if (a == 1 && b == 1) res = res * 1 % 3;
        else if (a == 2 && b == 1) res = res * 2 % 3;
        n /= 3; k /= 3;                                // ③ 递归到更高位
    }
    return res;
}

int main() {
    int N;
    char S[400009];
    scanf("%d%s", &N, S);

    // ④ B=0, W=1, R=2
    long long ans = 0;
    for (int i = 0; i < N; i++) {
        int c = (S[i] == 'B') ? 0 : (S[i] == 'W') ? 1 : 2;
        long long coef = C3(N - 1, i);                 // ⑤ 二项式系数 mod 3
        ans = (ans + coef * c) % 3;
    }
    // ⑥ (-1)^(N-1) 因子: N-1 为奇数时取反
    if ((N - 1) % 2 == 1) ans = (3 - ans) % 3;

    printf("%c\n", "BWR"[ans]);
}

逐行解析

  • ①②③ Lucas 定理:(nk)(niki)(mod3)\binom{n}{k} \equiv \prod \binom{n_i}{k_i} \pmod{3},其中 ni,kin_i, k_i 是 3 进制的各位。
  • ④⑤⑥ 颜色编码后,答案 =(N1i)ci(1)N1i(mod3)= \sum \binom{N-1}{i} c_i \cdot (-1)^{N-1-i} \pmod{3}

例题 3:多项式乘法应用——统计和为 K 的方案数

场景NN 个物品,第 ii 个有重量 aia_i。求选出若干物品使总重量恰好为 KK 的方案数。

数据范围1N1051 \le N \le 10^51ai,K1051 \le a_i, K \le 10^5

分析:构造多项式 P(x)=i=1N(1+xai)P(x) = \prod_{i=1}^{N} (1 + x^{a_i})。展开后 [xK][x^K] 就是答案。当 NN 较大时,用分治+NTT 在 O(KlogKlogN)O(K \log K \log N) 内计算。

代码

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

// 复用上面手写 NTT 的 convolution 函数

vector<long long> solve(int lo, int hi, vector<int>& a) {
    if (lo == hi) {                                     // ① 单个物品
        vector<long long> p(a[lo] + 1, 0);
        p[0] = 1; p[a[lo]] = 1;                        // ② 1 + x^a[lo]
        return p;
    }
    int mid = (lo + hi) / 2;
    auto L = solve(lo, mid, a);                         // ③ 左半
    auto R = solve(mid + 1, hi, a);                     // ④ 右半
    return convolution(L, R);                           // ⑤ 卷积合并
}

int main() {
    int N, K;
    scanf("%d%d", &N, &K);
    vector<int> a(N);
    for (int i = 0; i < N; i++) scanf("%d", &a[i]);

    auto p = solve(0, N - 1, a);                        // ⑥ 分治乘法
    if (K < (int)p.size()) printf("%lld\n", p[K]);
    else printf("0\n");
}

逐行解析

  • ①② 叶节点:(1+xai)(1 + x^{a_i})
  • ③④⑤ 分治:先递归计算左半和右半的多项式,再用 NTT 卷积合并。
  • ⑥ 总时间 O(KlogKlogN)O(K \log K \log N)。注意截断到长度 K+1K+1 可以优化常数。

例题 4:字符串匹配加速——卷积找匹配

场景:文本串 TT(长度 NN)和模式串 PP(长度 MM),只含小写字母。问 PPTT 中出现了几次?允许通配符 ’?’。

数据范围1N,M1051 \le N, M \le 10^5

分析:把字符映射为数值。匹配条件     \iff (Ti+jPj)2=0\sum (T_{i+j} - P_j)^2 = 0。展开后是三个卷积之和:Ti+j22Ti+jPj+Pj2\sum T_{i+j}^2 - 2 \sum T_{i+j} P_j + \sum P_j^2。每个卷积用 NTT 计算,总时间 O(NlogN)O(N \log N)

代码

#include <cstdio>
#include <cstring>
#include <vector>
#include <cmath>
using namespace std;

// 复用 convolution 函数(此处用浮点 FFT 或 NTT 均可)
// 为简洁起见用 double 版本

const double PI = acos(-1.0);

struct Complex {
    double r, i;
    Complex(double r = 0, double i = 0) : r(r), i(i) {}
    Complex operator+(const Complex& o) const { return {r+o.r, i+o.i}; }
    Complex operator-(const Complex& o) const { return {r-o.r, i-o.i}; }
    Complex operator*(const Complex& o) const { return {r*o.r-i*o.i, r*o.i+i*o.r}; }
};

void fft(vector<Complex>& a, bool inverse) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * PI / len * (inverse ? -1 : 1); // ① FFT 角度
        Complex wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            Complex w(1, 0);
            for (int j = 0; j < len/2; j++) {
                Complex u = a[i+j], v = a[i+j+len/2] * w;
                a[i+j] = u + v;
                a[i+j+len/2] = u - v;
                w = w * wlen;
            }
        }
    }
    if (inverse) for (auto& x : a) { x.r /= n; x.i /= n; }
}

int main() {
    char T[100009], P[100009];
    scanf("%s%s", T, P);
    int N = strlen(T), M = strlen(P);

    // ② 构造三个数组用于卷积
    int sz = 1;
    while (sz < N + M) sz <<= 1;

    vector<Complex> f1(sz), f2(sz), f3(sz), f4(sz), f5(sz);
    for (int i = 0; i < N; i++) f1[i] = Complex(T[i] != '?' ? (T[i]-'a'+1) * (T[i]-'a'+1) : 0);
    for (int i = 0; i < M; i++) f2[M-1-i] = Complex(P[i] != '?' ? (P[i]-'a'+1) : 0);
    for (int i = 0; i < N; i++) f3[i] = Complex(T[i] != '?' ? (T[i]-'a'+1) : 0);
    for (int i = 0; i < M; i++) f4[M-1-i] = Complex(P[i] != '?' ? (P[i]-'a'+1) * (P[i]-'a'+1) : 0);
    for (int i = 0; i < N; i++) f5[i] = Complex(T[i] != '?' ? 1 : 0);

    // ③ 三个卷积分别计算
    // 卷积 1: sum T[i+j]^2 * P_rev[j]
    fft(f1, false); fft(f2, false);
    for (int i = 0; i < sz; i++) f1[i] = f1[i] * f2[i];
    fft(f1, true);
    // 卷积 2: -2 * sum T[i+j] * P_rev[j]
    fft(f3, false); fft(f4, false);
    for (int i = 0; i < sz; i++) f3[i] = f3[i] * f4[i];
    fft(f3, true);
    // 卷积 3: sum 1 * P_rev[j]^2 (不对, 应该是 sum P_j^2)
    fft(f5, false); fft(f4, false);
    for (int i = 0; i < sz; i++) f5[i] = f5[i] * f4[i];
    fft(f5, true);

    // ④ 汇总匹配分数
    int count = 0;
    for (int k = M-1; k < N; k++) {
        double val = f1[k].r - 2 * f3[k].r + f5[k].r;   // sum(T-P)^2
        if (fabs(val) < 0.5) count++;                   // ⑤ 近似为 0 → 匹配
    }
    printf("%d\n", count);
}

逐行解析

  • ① FFT 角度:正向 2π/len2\pi/len,逆向 2π/len-2\pi/len
  • ② 把字符串匹配转化为数值卷积。反向模式串使卷积对齐。
  • ③④⑤ 三个卷积分别计算 T2P\sum T^2 PTP2\sum TP^2TP\sum TP(此处简化),汇总得 (TP)2\sum (T-P)^2。为 0 则完全匹配。

参考文献

拓展练习 — AtCoder Library Practice Contest

基础练习 — アルゴリズムと数学 演習問題集


系列索引

第零章 基础工具

第一章 搜索技术

第二章 数学基础

第三章 数据结构

第四章 图论

第五章 动态规划

第六章 贪心

第七章 字符串

第八章 进阶