2023年12月,卡内基梅隆大学的Albert Gu和普林斯顿大学的Tri Dao在arXiv上发表了一篇论文,声称首次实现了"线性时间的Transformer级别性能"。这篇论文的标题很朴素——《Mamba: Linear-Time Sequence Modeling with Selective State Spaces》,但其展示的数据却引发了广泛关注:在百万级token长度上,Mamba的推理吞吐量达到同规模Transformer的5倍。
这不是一个简单的速度优化故事。Mamba代表的是一种完全不同的序列建模范式——状态空间模型(State Space Models, SSM)——向Transformer主导地位的真正挑战。
Transformer的阿喀琉斯之踵
要理解Mamba为何重要,首先需要理解Transformer的根本困境。
自注意力机制的核心操作是计算序列中每对token之间的关系。对于一个长度为 $L$ 的序列,这需要构建一个 $L \times L$ 的注意力矩阵,计算复杂度为 $O(L^2)$。这意味着序列长度翻倍,计算量会变成四倍;序列长度增加到100万,计算量会变成一万亿倍。
更麻烦的是,自回归推理时,每生成一个新token都需要重新计算整个序列的注意力——因为模型需要"回头看"所有之前的token。为了加速这个过程,Transformer会缓存所有历史token的Key-Value向量(KV Cache),但这又带来了 $O(L)$ 的内存占用。当序列长度达到数十万token时,显存很快就会成为瓶颈。
各种优化方案应运而生。Flash Attention通过分块计算减少显存访问;Sliding Window Attention限制每个token只能"看到"附近的token;Ring Attention将长序列分片到多个GPU上计算。但这些方法要么牺牲了完整的注意力能力,要么只是在推迟问题的爆发,而非从根本上解决问题。
真正的解决方案需要回答一个更根本的问题:序列模型真的需要 $O(L^2)$ 的复杂度吗?
状态空间模型:从控制理论到深度学习
状态空间模型(SSM)提供了一个完全不同的视角。它源自控制理论,用于描述动态系统如何随时间演化。
一个连续时间的状态空间模型可以用两个方程描述:
$$h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$$$$y(t) = \mathbf{C}h(t) + \mathbf{D}x(t)$$第一个是状态方程,描述隐藏状态 $h(t)$ 如何从上一状态和当前输入 $x(t)$ 演化而来。$\mathbf{A}$ 是状态转移矩阵,控制信息如何随时间保留或衰减;$\mathbf{B}$ 是输入矩阵,控制新信息如何进入状态。
第二个是输出方程,描述如何从当前状态和输入产生输出 $y(t)$。$\mathbf{C}$ 将状态映射到输出,$\mathbf{D}$ 则是一个类似于skip connection的直通路径。
这个框架的精妙之处在于:隐藏状态 $h(t)$ 理论上可以压缩整个历史的所有信息。如果状态维度足够高,并且状态转移矩阵 $\mathbf{A}$ 设计得当,那么只需要知道当前状态和当前输入,就可以预测未来——不需要回头看任何历史token。
连续到离散:零阶保持
深度学习处理的是离散序列(如文本token),而非连续信号。因此需要将连续SSM离散化。
最常用的方法是零阶保持(Zero-Order Hold, ZOH):假设在每个时间步 $\Delta$ 内,输入信号保持恒定。这给出离散化公式:
$$\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})$$$$\bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}$$离散化后的SSM可以写成:
$$h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t$$$$y_t = \mathbf{C} h_t$$这看起来非常像RNN!关键区别在于:传统RNN的隐藏状态维度通常与输入维度相同或接近,而SSM的隐藏状态维度 $N$ 可以远大于输入维度 $D$,从而提供更强的表达能力。
三种计算模式
SSM有一个独特的优势:它可以以三种等价方式计算,各有优劣。
循环模式:直接按时间步递归计算。推理时只需要常量内存(存储当前状态),每步计算复杂度为 $O(N)$。这解决了Transformer的 $O(L)$ 内存问题。
卷积模式:对于线性时不变(LTI)系统,SSM可以表示为一个卷积操作。这意味着可以用FFT高效计算整个序列,实现训练时的并行化——解决了RNN无法并行训练的问题。
连续模式:保持数学上的连续形式,提供理论优雅性和分辨率不变性。
聪明的读者可能已经发现了一个陷阱:卷积模式要求 $\mathbf{A}$、$\mathbf{B}$、$\mathbf{C}$ 对所有时间步保持不变(线性时不变性),这意味着模型无法根据输入内容动态调整行为。
HiPPO与S4:长程依赖的数学基础
早期的SSM在长序列上表现不佳,核心问题是状态转移矩阵 $\mathbf{A}$ 如何初始化。随机初始化的 $\mathbf{A}$ 很容易导致梯度消失或爆炸,使得模型无法学习长程依赖。
2020年,Gu等人提出了HiPPO(High-order Polynomial Projection Operators) 理论,给出了一个优雅的答案。
HiPPO的核心洞察是:状态可以被视为历史输入的多项式逼近系数。如果用勒让德多项式作为基函数,那么状态转移矩阵 $\mathbf{A}$ 可以设计为:
$$\mathbf{A}_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$这个矩阵有一个重要性质:它能够更好地保留"近期"信息,同时逐渐衰减"远期"信息——这正是处理长序列所需要的。
HiPPO矩阵的Python实现:
def get_hippo_A(N: int) -> np.ndarray:
A = np.zeros((N, N))
for n in range(N):
for k in range(N):
if n > k:
A[n, k] = (2*n+1)**0.5 * (2*k+1)**0.5
elif n == k:
A[n, k] = n + 1
return A
基于HiPPO理论,2021年提出的S4(Structured State Space Sequence) 模型在Long Range Arena基准测试上首次超越了Transformer。S4的关键创新是将HiPPO矩阵分解为对角加低秩(Diagonal Plus Low-Rank, DPLR) 形式,使得卷积核可以通过FFT高效计算。
然而,S4仍然受限于LTI假设——它无法根据输入内容选择性地记住或遗忘信息。
Mamba的核心创新:选择性状态空间模型
Mamba的突破来自于一个简单但深刻的问题:如果让SSM参数($\mathbf{B}$、$\mathbf{C}$、$\Delta$)依赖于输入,会发生什么?
选择性机制
传统SSM中,$\mathbf{B}$、$\mathbf{C}$ 和步长 $\Delta$ 是固定的参数。Mamba将它们改为输入的函数:
$$\mathbf{B}_t = \text{Linear}_N(x_t)$$$$\mathbf{C}_t = \text{Linear}_N(x_t)$$$$\Delta_t = \text{Broadcast}_D(\text{Linear}_1(x_t))$$这意味着每个时间步都有不同的"记忆写入门"($\mathbf{B}$)和"记忆读取门"($\mathbf{C}$)。步长 $\Delta$ 则控制模型应该"聚焦"于当前输入还是"忽略"它——较大的 $\Delta$ 意味着更关注当前输入并重置状态,较小的 $\Delta$ 则意味着忽略当前输入并保持状态。
论文用两个合成任务展示了选择性的重要性:
选择性复制任务:输入中散布着需要记住的"关键"token和需要忽略的"噪声"token。LTI模型无法区分它们,而选择性SSM可以精准地只记住关键信息。
Induction Heads任务:模型需要根据上下文"联想"出答案。例如,如果看到"Harry Potter",下次看到"Harry"时应该预测"Potter"。这需要根据内容动态决定何时调用历史信息。
硬件感知算法:让选择性成为可能
引入选择性后,SSM失去了LTI性质,无法再使用卷积模式进行高效训练。但这恰恰暴露了一个更深层的机会。
传统观点认为卷积模式比循环模式更高效,因为卷积可以并行计算。然而,这个"高效"是针对FLOPs而言的。在现代GPU上,真正的瓶颈往往不是计算能力,而是显存带宽——数据在高速SRAM和低速HBM之间传输的速度。
Mamba的硬件感知算法采用了三个关键技术:
内核融合(Kernel Fusion):将离散化、状态更新和输出计算合并到一个CUDA内核中。这样,中间状态只在快速的SRAM中存在,无需写入慢速的HBM。
并行扫描(Parallel Scan):尽管递归计算看起来是串行的,但实际上可以通过分治策略并行化。将序列分成多段,分别计算每段的"前缀和",然后合并结果。
重计算(Recomputation):前向传播时不保存中间状态,而是在反向传播时重新计算。虽然增加了计算量,但避免了大量显存访问,在GPU上反而更快。
结果是:Mamba的训练速度比之前最优的SSM实现快3倍,推理吞吐量比同规模Transformer高5倍。
Mamba-2与状态空间对偶性
2024年5月,Mamba的原作者发布了Mamba-2,提出了一个更深刻的理论框架——状态空间对偶性(State Space Duality, SSD)。
核心洞察是:选择性SSM和注意力机制之间存在数学上的对偶关系。
具体来说,对于标量结构的SSM($\mathbf{A}$ 是标量乘以单位矩阵),其输出可以写成一个类似注意力的矩阵乘法:
$$\mathbf{M} = \mathbf{L} \circ \mathbf{C}\mathbf{B}^\top$$其中 $\mathbf{L}$ 是一个下三角矩阵,$\circ$ 是逐元素乘法。如果所有 $a_t = 1$,$\mathbf{L}$ 就是标准的因果掩码,而 $\mathbf{C}\mathbf{B}^\top$ 对应 $\mathbf{Q}\mathbf{K}^\top$——这正是因果线性注意力的形式!
这个对偶性带来了两个重要好处:
-
训练加速:可以利用高度优化的矩阵乘法来计算SSM,而不是依赖专门的CUDA内核。
-
理论统一:注意力机制和状态空间模型被置于同一个数学框架下,为未来的架构创新提供了清晰的路线图。
Mamba-2还支持更大的状态维度(从N=16增加到N=64甚至256),同时训练速度进一步提升。在多查询联想回忆任务上,Mamba-2的表现显著优于Mamba-1。
性能对比:数字说话
在语言建模基准测试上,Mamba展现出了与Transformer相当甚至更好的扩展性。
| 模型 | 参数量 | Pile PPL | LAMBADA | HellaSwag | 平均分 |
|---|---|---|---|---|---|
| Pythia-2.8B | 2.8B | 6.73 | 64.7 | 74.0 | 59.1 |
| RWKV-3B | 3B | 7.00 | 63.9 | 73.7 | 59.6 |
| Mamba-2.8B | 2.8B | 6.22 | 69.2 | 75.2 | 63.3 |
| Pythia-6.9B | 6.9B | 6.51 | 67.1 | 75.2 | 61.7 |
值得注意的是,Mamba-2.8B的性能甚至超过了参数量两倍于它的Pythia-6.9B。
在推理效率上,差距更为明显。由于不需要维护KV Cache,Mamba可以使用更大的批处理大小。在A100 GPU上,Mamba-1.4B的推理吞吐量比Transformer-1.3B高出4-5倍。
更令人印象深刻的是长上下文能力。在DNA序列建模任务上,Mamba能够利用长达100万token的上下文,性能随上下文长度增加而持续提升;而对比模型HyenaDNA在长上下文上反而性能下降——因为它无法选择性地过滤噪声信息。
局限性与未来方向
尽管Mamba展现了令人振奋的潜力,但它并非没有局限。
生态系统成熟度:Transformer拥有庞大的生态系统,包括预训练模型、微调方法、量化技术、推理优化等。Mamba还在早期阶段,这些工具和最佳实践仍在发展中。
缩放验证:Mamba的论文主要在较小规模(<3B参数)上验证。更大规模的模型是否保持同样的优势,仍需进一步研究。
连续vs离散谱系:SSM最初是为连续信号设计的,在音频等感知任务上表现优异。选择性机制帮助它处理离散文本,但在某些纯连续信号任务上可能反而不及原始的非选择性SSM。
混合架构的探索:也许最优的方案不是"纯Mamba"或"纯Transformer",而是两者的结合。例如,可以在近期token上使用注意力(高精度短期记忆),在远期token上使用SSM(高效长期记忆)。
结语:序列建模的新范式
Mamba的意义不仅在于技术性能的提升,更在于它打开了序列建模的新范式。
长期以来,Transformer被视为"唯一的选择"。每一个试图挑战它的架构——从线性注意力到高效Transformers——都以某种方式牺牲了性能或表达能力。Mamba第一次证明了:线性复杂度和Transformer级别性能可以兼得。
更重要的是,状态空间对偶性揭示了注意力机制和状态空间模型的深层联系。也许它们不是竞争关系,而是同一枚硬币的两面——在不同的计算约束和应用场景下各有优势。
当序列长度从几千扩展到几百万,当模型需要真正的"长期记忆",当推理效率成为关键瓶颈——在这些场景下,Mamba和它的后继者们将有机会重新定义序列建模的可能性。
参考文献
- Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752.
- Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022.
- Gu, A., Dao, T., et al. (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020.
- Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. arXiv:2405.21060.
- Smith, J. T. H., Warrington, A., & Linderman, S. W. (2023). Simplified State Space Layers for Sequence Modeling. ICLR 2023.
- Katharopoulos, A., et al. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020.
- Olsson, C., et al. (2022). In-context Learning and Induction Heads. Transformer Circuits Thread.