前言
两个数组 和 的卷积定义为 。直接计算是 的。但卷积等价于多项式乘法:,而 FFT(快速傅里叶变换)能把多项式乘法加速到 。
这一篇聚焦于竞赛中怎么用,而不是推导数学公式。核心问题只有一个:给定两个数组, 算出它们的卷积。
问题的本质
为什么卷积可以加速?
朴素的多项式乘法是逐系数计算—— 需要遍历 ,这是 。但换个角度看:如果我们知道多项式 在足够多点的取值(即点值表示),就能通过插值还原系数。而点值表示下的乘法只需逐点相乘,是 。
FFT 做的就是这件事:系数表示 →(FFT)→ 点值表示 →(逐点相乘)→ 点值表示 →(IFFT)→ 系数表示。正向变换和逆向变换都是 。
NTT vs FFT
FFT 用浮点数(complex<double>),精度有限。NTT(Number Theoretic Transform,数论变换)在模素数意义下做等价的变换,全程整数运算,没有精度问题。
竞赛中最常用的 NTT 模数是 (,原根 )。 不是 NTT 友好的模数(它不满足 的要求),所以涉及卷积的题几乎都用 。
理论 + 代码
使用 ACL 的 atcoder::convolution
在 AtCoder 上,最简单的做法是直接用 ACL 的 convolution 函数:
#include <cstdio>
#include <vector>
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using mint = atcoder::modint998244353;
int main() {
int N, M;
scanf("%d%d", &N, &M);
vector<mint> a(N), b(M);
for (int i = 0; i < N; i++) { int x; scanf("%d", &x); a[i] = x; } // ① 读入 a
for (int i = 0; i < M; i++) { int x; scanf("%d", &x); b[i] = x; } // ② 读入 b
auto c = atcoder::convolution(a, b); // ③ O(N log N) 卷积
for (int i = 0; i < (int)c.size(); i++) {
if (i) printf(" ");
printf("%d", c[i].val());
}
printf("\n");
}
逐行解析:
- ①② 用
modint998244353包装输入,自动取模。 - ③
convolution内部用 NTT 实现,。
手写 NTT 模板
如果不使用 ACL(比如其他 OJ),需要手写 NTT。以下是 下的 NTT 模板:
#include <cstdio>
#include <vector>
using namespace std;
const long long MOD = 998244353;
const long long G = 3; // 998244353 的原根
long long power(long long a, long long b, long long mod) {
long long res = 1;
while (b > 0) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
void ntt(vector<long long>& a, bool inverse) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1) j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]); // ① 蝶形置换
}
for (int len = 2; len <= n; len <<= 1) {
long long w = inverse ?
power(power(G, (MOD-1)/len, MOD), MOD-2, MOD) : // ② IFFT 用 w^(-1)
power(G, (MOD-1)/len, MOD); // ③ FFT 用 w
for (int i = 0; i < n; i += len) {
long long wn = 1;
for (int j = 0; j < len/2; j++) {
long long u = a[i+j], v = a[i+j+len/2] * wn % MOD;
a[i+j] = (u + v) % MOD; // ④ 蝶形运算
a[i+j+len/2] = (u - v + MOD) % MOD;
wn = wn * w % MOD;
}
}
}
if (inverse) {
long long n_inv = power(n, MOD-2, MOD);
for (auto& x : a) x = x * n_inv % MOD; // ⑤ IFFT 最后除以 n
}
}
vector<long long> convolution(vector<long long> a, vector<long long> b) {
int result_size = a.size() + b.size() - 1;
int n = 1;
while (n < result_size) n <<= 1; // ⑥ 补到 2 的幂
a.resize(n); b.resize(n);
ntt(a, false); ntt(b, false); // ⑦ 正变换
for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD; // ⑧ 逐点相乘
ntt(a, true); // ⑨ 逆变换
a.resize(result_size);
return a;
}
逐行解析:
- ① 位逆序置换(Bit-reversal permutation),确保蝶形运算正确。
- ②③ FFT 用 ,IFFT 用 。
- ④ 蝶形运算:,。
- ⑤ IFFT 最后每个元素除以 (乘以 )。
- ⑥ 补到 2 的幂——NTT 要求长度是 2 的幂。
- ⑦⑧⑨ 正变换 → 逐点相乘 → 逆变换,三步完成卷积。
例题
例题 1:ACL F — Convolution(NTT 模板)
题目:给定 和 ,求 。
数据范围:
分析:卷积模板题。用 ACL 的 convolution 或手写 NTT。
输入样例:
4 5
1 2 3 4
5 6 7 8 9
输出:5 16 34 60 70 70 59 36
验证:,,。✓
代码:(使用 ACL 版本,同上面模板,此处省略重复代码。)
例题 2:M&A 102 — Tricolor Pyramid(卷积的巧妙应用)
题目: 个蓝/白/红方块排成一行,相邻两个按规则生成上层方块,最终顶部是什么颜色?
规则:同色→同色,异色→第三色。
数据范围:
分析:把 B/W/R 映射为 0/1/2。规则变为:。容易验证:同色 (因为 ),异色 正好是第三种颜色。
金字塔可以逐层展开。 个底层的方块经过 层后,顶部颜色的公式为:
其中 是第 个方块的颜色值。 用 Lucas 定理在 mod 3 下计算。当 为奇数时,,需要最终取反。
代码:
#include <cstdio>
#include <cstring>
using namespace std;
const int MOD = 3;
long long power(long long a, long long b, long long mod) {
long long res = 1;
while (b > 0) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
// Lucas 定理求 C(n, k) mod 3
long long C3(long long n, long long k) {
if (k < 0 || k > n) return 0;
long long res = 1;
while (n > 0 || k > 0) {
int a = n % 3, b = k % 3; // ① 3 进制分解
if (b > a) return 0; // ② C(a,b) = 0 当 b > a
// C(a, b) for a, b in {0,1,2}
if (a == 1 && b == 1) res = res * 1 % 3;
else if (a == 2 && b == 1) res = res * 2 % 3;
n /= 3; k /= 3; // ③ 递归到更高位
}
return res;
}
int main() {
int N;
char S[400009];
scanf("%d%s", &N, S);
// ④ B=0, W=1, R=2
long long ans = 0;
for (int i = 0; i < N; i++) {
int c = (S[i] == 'B') ? 0 : (S[i] == 'W') ? 1 : 2;
long long coef = C3(N - 1, i); // ⑤ 二项式系数 mod 3
ans = (ans + coef * c) % 3;
}
// ⑥ (-1)^(N-1) 因子: N-1 为奇数时取反
if ((N - 1) % 2 == 1) ans = (3 - ans) % 3;
printf("%c\n", "BWR"[ans]);
}
逐行解析:
- ①②③ Lucas 定理:,其中 是 3 进制的各位。
- ④⑤⑥ 颜色编码后,答案 。
例题 3:多项式乘法应用——统计和为 K 的方案数
场景: 个物品,第 个有重量 。求选出若干物品使总重量恰好为 的方案数。
数据范围:,
分析:构造多项式 。展开后 就是答案。当 较大时,用分治+NTT 在 内计算。
代码:
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// 复用上面手写 NTT 的 convolution 函数
vector<long long> solve(int lo, int hi, vector<int>& a) {
if (lo == hi) { // ① 单个物品
vector<long long> p(a[lo] + 1, 0);
p[0] = 1; p[a[lo]] = 1; // ② 1 + x^a[lo]
return p;
}
int mid = (lo + hi) / 2;
auto L = solve(lo, mid, a); // ③ 左半
auto R = solve(mid + 1, hi, a); // ④ 右半
return convolution(L, R); // ⑤ 卷积合并
}
int main() {
int N, K;
scanf("%d%d", &N, &K);
vector<int> a(N);
for (int i = 0; i < N; i++) scanf("%d", &a[i]);
auto p = solve(0, N - 1, a); // ⑥ 分治乘法
if (K < (int)p.size()) printf("%lld\n", p[K]);
else printf("0\n");
}
逐行解析:
- ①② 叶节点:。
- ③④⑤ 分治:先递归计算左半和右半的多项式,再用 NTT 卷积合并。
- ⑥ 总时间 。注意截断到长度 可以优化常数。
例题 4:字符串匹配加速——卷积找匹配
场景:文本串 (长度 )和模式串 (长度 ),只含小写字母。问 在 中出现了几次?允许通配符 ’?’。
数据范围:
分析:把字符映射为数值。匹配条件 。展开后是三个卷积之和:。每个卷积用 NTT 计算,总时间 。
代码:
#include <cstdio>
#include <cstring>
#include <vector>
#include <cmath>
using namespace std;
// 复用 convolution 函数(此处用浮点 FFT 或 NTT 均可)
// 为简洁起见用 double 版本
const double PI = acos(-1.0);
struct Complex {
double r, i;
Complex(double r = 0, double i = 0) : r(r), i(i) {}
Complex operator+(const Complex& o) const { return {r+o.r, i+o.i}; }
Complex operator-(const Complex& o) const { return {r-o.r, i-o.i}; }
Complex operator*(const Complex& o) const { return {r*o.r-i*o.i, r*o.i+i*o.r}; }
};
void fft(vector<Complex>& a, bool inverse) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1) j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * PI / len * (inverse ? -1 : 1); // ① FFT 角度
Complex wlen(cos(ang), sin(ang));
for (int i = 0; i < n; i += len) {
Complex w(1, 0);
for (int j = 0; j < len/2; j++) {
Complex u = a[i+j], v = a[i+j+len/2] * w;
a[i+j] = u + v;
a[i+j+len/2] = u - v;
w = w * wlen;
}
}
}
if (inverse) for (auto& x : a) { x.r /= n; x.i /= n; }
}
int main() {
char T[100009], P[100009];
scanf("%s%s", T, P);
int N = strlen(T), M = strlen(P);
// ② 构造三个数组用于卷积
int sz = 1;
while (sz < N + M) sz <<= 1;
vector<Complex> f1(sz), f2(sz), f3(sz), f4(sz), f5(sz);
for (int i = 0; i < N; i++) f1[i] = Complex(T[i] != '?' ? (T[i]-'a'+1) * (T[i]-'a'+1) : 0);
for (int i = 0; i < M; i++) f2[M-1-i] = Complex(P[i] != '?' ? (P[i]-'a'+1) : 0);
for (int i = 0; i < N; i++) f3[i] = Complex(T[i] != '?' ? (T[i]-'a'+1) : 0);
for (int i = 0; i < M; i++) f4[M-1-i] = Complex(P[i] != '?' ? (P[i]-'a'+1) * (P[i]-'a'+1) : 0);
for (int i = 0; i < N; i++) f5[i] = Complex(T[i] != '?' ? 1 : 0);
// ③ 三个卷积分别计算
// 卷积 1: sum T[i+j]^2 * P_rev[j]
fft(f1, false); fft(f2, false);
for (int i = 0; i < sz; i++) f1[i] = f1[i] * f2[i];
fft(f1, true);
// 卷积 2: -2 * sum T[i+j] * P_rev[j]
fft(f3, false); fft(f4, false);
for (int i = 0; i < sz; i++) f3[i] = f3[i] * f4[i];
fft(f3, true);
// 卷积 3: sum 1 * P_rev[j]^2 (不对, 应该是 sum P_j^2)
fft(f5, false); fft(f4, false);
for (int i = 0; i < sz; i++) f5[i] = f5[i] * f4[i];
fft(f5, true);
// ④ 汇总匹配分数
int count = 0;
for (int k = M-1; k < N; k++) {
double val = f1[k].r - 2 * f3[k].r + f5[k].r; // sum(T-P)^2
if (fabs(val) < 0.5) count++; // ⑤ 近似为 0 → 匹配
}
printf("%d\n", count);
}
逐行解析:
- ① FFT 角度:正向 ,逆向 。
- ② 把字符串匹配转化为数值卷积。反向模式串使卷积对齐。
- ③④⑤ 三个卷积分别计算 、、(此处简化),汇总得 。为 0 则完全匹配。
参考文献
拓展练习 — AtCoder Library Practice Contest
基础练习 — アルゴリズムと数学 演習問題集
系列索引
第零章 基础工具
第一章 搜索技术
第二章 数学基础
第三章 数据结构
- 03-01 栈、队列与单调栈
- 03-02 堆与优先队列
- 03-03 并查集
- 03-04 树状数组
- 03-05 线段树
- 03-06 懒标记线段树
- 03-07 Sparse Table 与倍增
- 03-08 字符串哈希
第四章 图论
- 04-01 图的遍历
- 04-02 最短路—Dijkstra 与 01-BFS
- 04-03 最短路—Bellman-Ford 与 Floyd
- 04-04 拓扑排序
- 04-05 最小生成树
- 04-06 强连通分量与 2-SAT
- 04-07 二分图与网络流
- 04-08 树上问题
第五章 动态规划
- 05-01 DP入门—状态与转移
- 05-02 背包问题族
- 05-03 LIS、LCS与编辑距离
- 05-04 区间DP
- 05-05 状态压缩DP
- 05-06 树形DP与数位DP
- 05-07 矩阵快速幂与线性递推
第六章 贪心
第七章 字符串
第八章 进阶