2019年11月,Noam Shazeer在arXiv上发表了一篇标题颇为大胆的论文——《Fast Transformer Decoding: One Write-Head is All You Need》。这篇仅6页的论文提出了一个看似简单的问题:Transformer解码时,我们真的需要那么多Key和Value头吗?

答案是否定的。通过让所有Query头共享同一组Key和Value,推理速度可以提升数倍,而模型质量的损失却微乎其微。这就是**多查询注意力(Multi-Query Attention,MQA)**的诞生。

一个反直觉的现象:推理比训练慢

在深度学习的世界里,训练通常被认为是更"昂贵"的操作——需要反向传播、需要存储梯度、需要更新数十亿参数。推理不过是简单的前向传播,应该很快才对。

但Transformer的自回归解码打破了这个直觉。

graph LR
    subgraph "训练阶段(并行)"
        T1["Token 1"] --> T_ALL["同时计算所有注意力"]
        T2["Token 2"] --> T_ALL
        T3["Token 3"] --> T_ALL
        T4["Token 4"] --> T_ALL
    end
    
    subgraph "推理阶段(串行)"
        I1["生成 Token 1"] --> I2["生成 Token 2"]
        I2 --> I3["生成 Token 3"]
        I3 --> I4["生成 Token 4"]
    end
    
    style T_ALL fill:#e1f5fe
    style I1 fill:#fff3e0
    style I2 fill:#fff3e0
    style I3 fill:#fff3e0
    style I4 fill:#fff3e0

在训练阶段,我们可以并行处理整个序列。假设输入是"今天天气不错",模型可以同时计算每个位置对之前所有位置的注意力——第1个词看自己,第2个词看前两个,第3个词看前三个。这种并行化让GPU的计算能力得到充分利用。

推理则完全不同。生成第1个词时,模型只知道输入;生成第2个词时,模型才知道第1个词是什么;生成第3个词时,模型才能看到前两个词。每生成一个新词,都需要重新计算注意力,而且计算量会随着序列长度线性增长。

更关键的是,每生成一个token,都需要从GPU的高带宽内存(HBM)中加载整个KV缓存——存储了之前所有token的Key和Value向量。这个缓存有多大?

graph TD
    A["KV Cache大小计算"] --> B["层数 L"]
    A --> C["隐藏维度 d"]
    A --> D["序列长度 s"]
    A --> E["字节数 bytes"]
    
    B --> F["KV = 2 × L × d × s × bytes"]
    C --> F
    D --> F
    E --> F
    
    F --> G["70B模型示例<br/>L=32, d=8192, s=4096"]
    G --> H["KV = 2 × 32 × 8192 × 4096 × 2<br/>= 4.3 GB"]
    
    style H fill:#ffcdd2

对于一个70B参数、32层、隐藏维度8192、序列长度4096的模型:

$$\text{KV Cache} = 2 \times L \times d \times s \times \text{bytes} = 2 \times 32 \times 8192 \times 4096 \times 2 = 4.3 \text{ GB}$$

这只是存储一个请求的KV缓存。如果同时处理16个请求,仅KV缓存就需要约70GB——超过了A100 80GB显卡除去模型权重后的可用空间。

内存带宽:被忽视的瓶颈

现代GPU的发展呈现出一个明显的趋势:计算能力的增长远快于内存带宽的增长。

graph LR
    subgraph "NVIDIA GPU演进"
        A["V100<br/>FP32: 15.7 TFLOPS<br/>带宽: 900 GB/s"] 
        B["A100<br/>FP32: 19.5 TFLOPS<br/>带宽: 2039 GB/s"]
        C["H100<br/>FP32: 67 TFLOPS<br/>带宽: 3352 GB/s"]
    end
    
    A --> B --> C
    
    D["计算能力增长: 4.3x"] 
    E["内存带宽增长: 3.7x"]
    
    style A fill:#e3f2fd
    style B fill:#bbdefb
    style C fill:#90caf9
    style D fill:#fff3e0
    style E fill:#ffccbc

以NVIDIA A100为例,其FP32计算能力达到19.5 TFLOPS,而内存带宽约为2 TB/s。这意味着,如果每个浮点运算需要从内存读取4字节输入并写入4字节输出,那么达到峰值计算能力需要:

$$\frac{19.5 \times 10^{12} \times 8}{2 \times 10^{12}} = 78 \text{ FLOP/byte}$$

也就是说,每从内存读取1字节数据,需要执行78次浮点运算,才能让计算单元"忙起来"。这被称为算术强度(Arithmetic Intensity)

graph TD
    subgraph "Roofline模型"
        A["内存带宽受限区域"] -->|算术强度增加| B["计算能力受限区域"]
        A --> C["斜率 = 内存带宽"]
        B --> D["峰值 = 计算能力"]
    end
    
    E["低算术强度操作<br/>(如自回归解码)"] -.-> A
    F["高算术强度操作<br/>(如矩阵乘法)"] -.-> B
    
    G["MQA的目标:<br/>提高算术强度"] --> F
    
    style A fill:#ffcdd2
    style B fill:#c8e6c9
    style E fill:#ffecb3
    style G fill:#b3e5fc

Roofline模型清晰地描述了这一关系:当算术强度低于某个阈值时,程序受限于内存带宽;当算术强度高于这个阈值时,程序才受限于计算能力。

问题在于,Transformer的自回归解码恰恰是一个算术强度极低的操作。每生成一个token,需要加载整个KV缓存,却只执行一次注意力计算。随着序列变长,数据量线性增长,但计算量的增长相对缓慢。

这就是为什么即使你有强大的计算能力,推理速度依然受限于内存带宽——你的GPU在等待数据从内存"流"过来,而不是在疯狂计算。

MQA的核心思想:用共享换取效率

面对内存带宽瓶颈,Shazeer提出了一个简单的解决方案:减少需要加载的数据量。

在标准的多头注意力(MHA)中,每个注意力头都有独立的Query、Key和Value投影。假设有32个头,那么对于每个token,需要存储32组Key和32组Value。

MQA的核心改动极其简单:保留多个Query头,但让所有头共享同一组Key和Value。

graph TD
    subgraph "MHA(多头注意力)"
        MHA_Q1["Query Head 1"] --> MHA_A1["Attention 1"]
        MHA_K1["Key Head 1"] --> MHA_A1
        MHA_V1["Value Head 1"] --> MHA_A1
        
        MHA_Q2["Query Head 2"] --> MHA_A2["Attention 2"]
        MHA_K2["Key Head 2"] --> MHA_A2
        MHA_V2["Value Head 2"] --> MHA_A2
        
        MHA_Qn["Query Head n"] --> MHA_An["Attention n"]
        MHA_Kn["Key Head n"] --> MHA_An
        MHA_Vn["Value Head n"] --> MHA_An
    end
    
    style MHA_K1 fill:#ffcdd2
    style MHA_K2 fill:#ffcdd2
    style MHA_Kn fill:#ffcdd2
    style MHA_V1 fill:#ffcdd2
    style MHA_V2 fill:#ffcdd2
    style MHA_Vn fill:#ffcdd2
graph TD
    subgraph "MQA(多查询注意力)"
        MQA_Q1["Query Head 1"] --> MQA_A1["Attention 1"]
        MQA_Q2["Query Head 2"] --> MQA_A2["Attention 2"]
        MQA_Qn["Query Head n"] --> MQA_An["Attention n"]
        
        MQA_K["共享 Key"] --> MQA_A1
        MQA_K --> MQA_A2
        MQA_K --> MQA_An
        
        MQA_V["共享 Value"] --> MQA_A1
        MQA_V --> MQA_A2
        MQA_V --> MQA_An
    end
    
    style MQA_K fill:#c8e6c9
    style MQA_V fill:#c8e6c9

这个改动的影响是立竿见影的。对于32头的模型,KV缓存的大小减少了32倍。更重要的是,推理时需要从内存加载的数据量也减少了32倍。

graph LR
    subgraph "KV缓存大小对比"
        MHA["MHA<br/>32头<br/>512 KB/token"]
        MQA["MQA<br/>1组KV<br/>16 KB/token"]
    end
    
    MHA -->|"减少 32 倍"| MQA
    
    style MHA fill:#ffcdd2
    style MQA fill:#c8e6c9

这意味着什么?让我们用一个具体例子说明。Falcon 7B模型有32个注意力头,如果使用MHA,每个token的KV缓存需要:

$$32 \times 2 \times 4096 \times 2 = 524,288 \text{ bytes} \approx 512 \text{ KB}$$

而使用MQA后:

$$1 \times 2 \times 4096 \times 2 = 16,384 \text{ bytes} \approx 16 \text{ KB}$$

减少了32倍。

为什么这能加速推理?

MQA加速推理的机制可以分解为三个层面:

graph TD
    A["MQA加速机制"] --> B["第一层:内存占用减少"]
    A --> C["第二层:内存带宽压力降低"]
    A --> D["第三层:算术强度提升"]
    
    B --> B1["更小的KV缓存<br/>→ 更多并发请求"]
    C --> C1["减少数据加载量<br/>→ 内存带宽消耗降低"]
    D --> D1["数据重用率提高<br/>→ 从带宽受限转为计算受限"]
    
    B1 --> E["整体效果:<br/>推理速度提升数倍"]
    C1 --> E
    D1 --> E
    
    style D1 fill:#c8e6c9
    style E fill:#b3e5fc

第一层:内存占用减少

更小的KV缓存意味着可以在同样的GPU内存中容纳更多的并发请求。Fireworks AI的测试显示,在A100上使用Falcon 40B模型,MQA可以将最大批次大小提升16倍。在显存更小的A6000上,MQA使长序列处理从"不可行"变成了"可行"。

第二层:内存带宽压力降低

推理时,每个token生成都需要加载整个KV缓存。减少KV缓存的大小,直接减少了内存带宽的消耗。当KV缓存从512KB降到16KB时,内存带宽的使用也相应减少了32倍。

第三层:算术强度提升

这是最关键的一点。MQA通过共享Key和Value,实际上增加了数据的重用率。同一个Key-Value对被32个Query头使用,这意味着每次从内存加载的数据被执行了更多次计算。

算术强度从原来的每字节数据对应很少次运算,变成了每字节数据对应更多次运算。当算术强度超过GPU的"脊点"(Ridge Point)时,程序从内存带宽受限转变为计算能力受限——这正是我们想要的。

实际性能数据

Fireworks AI的基准测试提供了具体的数据:

graph TD
    subgraph "延迟优化结果"
        L1["A100 + Falcon 40B<br/>延迟降低约 30%"]
        L2["A6000 + Falcon 7B<br/>延迟降低约 35%"]
        L3["A6000 + Falcon 40B<br/>延迟降低约 40%"]
    end
    
    subgraph "吞吐优化结果"
        T1["A100 + Falcon 40B<br/>吞吐提升约 3x"]
        T2["A6000 + Falcon 7B<br/>吞吐提升 > 10x"]
    end
    
    style L1 fill:#c8e6c9
    style L2 fill:#c8e6c9
    style L3 fill:#c8e6c9
    style T1 fill:#b3e5fc
    style T2 fill:#b3e5fc

特别值得注意的是A6000的场景。由于显存有限,不使用MQA时,Falcon 40B甚至无法处理较长的序列。MQA不仅提升了速度,更扩展了模型的使用边界。

代价:不可避免的性能损失

没有免费的午餐。MQA的效率提升是有代价的。

graph LR
    A["MQA的权衡"] --> B["效率提升<br/>速度 1.5-2x<br/>吞吐 3-11x"]
    A --> C["代价<br/>质量下降 ~2%<br/>训练不稳定"]
    
    D["多头注意力的设计初衷"] --> E["不同头关注不同子空间"]
    E --> F["语法结构 / 语义关联 / 长距离依赖"]
    
    G["共享KV的局限"] --> H["表达能力被压缩到单一子空间"]
    
    style B fill:#c8e6c9
    style C fill:#ffcdd2
    style H fill:#fff3e0

多篇研究指出,MQA会导致模型质量下降约2%。这在某些场景下可能是可接受的,但在追求极致性能的场景中,这可能是一个不可忽视的权衡。

为什么共享Key和Value会影响模型能力?关键在于多头注意力的设计初衷。

原始Transformer论文的作者写道:“多头注意力使模型能够同时关注来自不同位置的不同表示子空间的信息。“每个注意力头学习关注输入的不同方面——有的头关注语法结构,有的头关注语义关联,有的头关注长距离依赖。

当所有头被迫共享同一组Key和Value时,这种多样性受到了限制。虽然Query头仍然可以学习不同的"提问方式”,但Key和Value的表达能力被压缩到了一个子空间。这就像让32个人用同一本字典来查找不同的词——虽然他们可以查找不同的内容,但字典本身的内容是固定的。

更严重的问题是训练不稳定性。研究显示,使用MQA训练模型时,如果采用与MHA相同的学习率,容易出现训练不稳定。建议的做法是将学习率降低40-50%,这无疑延长了训练时间。

从极端到折中:GQA的诞生

MQA的问题促使研究者在2023年提出了分组查询注意力(Grouped-Query Attention,GQA)

GQA的思想很简单:与其走极端(要么每个头独立的KV,要么所有头共享一组KV),不如找一个中间点。将Query头分成若干组,每组共享一组Key和Value。

graph TD
    subgraph "MHA(h组KV)"
        MHA1["Head 1: Q1, K1, V1"]
        MHA2["Head 2: Q2, K2, V2"]
        MHAh["Head h: Qh, Kh, Vh"]
    end
    
    subgraph "GQA(g组KV, g < h)"
        GQA1["Group 1: Q1-Q4 共享 K1, V1"]
        GQAg["Group g: Qh-3-Qh 共享 Kg, Vg"]
    end
    
    subgraph "MQA(1组KV)"
        MQA1["All Q heads share K, V"]
    end
    
    MHA1 --> GQA1 --> MQA1
    
    style MHA1 fill:#ffcdd2
    style GQA1 fill:#fff3e0
    style MQA1 fill:#c8e6c9

假设模型有32个Query头,GQA可以使用8组KV(每个KV被4个Query头共享)。这样KV缓存减少了4倍,同时保留了比MQA更多的表达能力。

GQA的实验结果令人鼓舞:使用8组KV的GQA,在推理速度上接近MQA,而在模型质量上几乎与MHA持平。这正是人们期待的"两全其美"的方案。

实际应用:谁在使用MQA?

MQA和GQA已经在大模型生态中广泛应用:

timeline
    title MQA/GQA应用时间线
    2019 : PaLM 使用 MQA
    2023.05 : Falcon 发布<br/>使用改进版 MQA/GQA
    2023.07 : Llama 2 发布<br/>使用 GQA-8
    2024 : Llama 3, Mistral 等<br/>广泛采用 GQA

PaLM系列:Google的PaLM模型采用了MQA架构,这是最早大规模使用MQA的模型之一。

Falcon系列:TII的Falcon模型使用了一种改进的MQA。在Falcon 40B中,num_heads=128,num_kv_heads=8——相当于GQA-8。这反映了从纯MQA向GQA过渡的趋势。

Llama 2/3:Meta的Llama系列使用了GQA架构。Llama 2 70B使用GQA-8,在保持推理效率的同时获得了更好的模型质量。

StarCoder:代码生成模型StarCoder同样采用了MQA架构,这对于需要长上下文的代码生成任务尤其重要。

值得注意的是,早期使用MQA的模型后来大多转向了GQA。这并非因为MQA的效果不好,而是因为GQA提供了更灵活的权衡空间——你可以根据具体需求在速度和质量之间选择最佳平衡点。

工程实现的细节

MQA的实现远比概念复杂。关键挑战在于:现有的注意力计算内核(如FlashAttention、PagedAttention)通常假设Key和Value的形状与Query相同。

graph TD
    A["MQA实现挑战"] --> B["现有内核假设Q/K/V形状相同"]
    B --> C["方案1: 预先广播KV"]
    B --> D["方案2: 内核内处理广播"]
    
    C --> C1["问题: 额外内存开销<br/>抵消MQA优势"]
    D --> D1["优势: 保持MQA优势<br/>需要修改内核代码"]
    
    D1 --> E["Fireworks AI 方案:<br/>扩展自定义注意力内核"]
    
    style C1 fill:#ffcdd2
    style D1 fill:#c8e6c9

一种直观的实现方式是在注意力计算前"广播"Key和Value——将单一的KV复制多份以匹配Query头的数量。但这样做会带来额外的内存开销和复制操作,抵消了MQA的部分收益。

更优的做法是在注意力内核内部处理这种广播。这意味着需要修改或重新实现注意力计算的核心代码。Fireworks AI在他们的实现中扩展了自定义注意力内核,使其能够内联处理KV广播,从而在保持MQA优势的同时避免了额外的内存开销。

另一个工程细节是多GPU场景下的处理。当模型需要分布在多个GPU上时(如张量并行),MQA的KV共享机制需要特殊处理。Falcon论文提出的方案是将KV头数设置为GPU数量——每个GPU维护自己的KV副本,避免了GPU间的通信开销。

MQA的局限性

尽管MQA带来了显著的推理加速,但它并非万能药。

graph TD
    A["MQA局限性"] --> B["需要重新训练"]
    A --> C["解码阶段收益更明显"]
    A --> D["其他优化技术的竞争"]
    A --> E["更适合解码器架构"]
    
    B --> B1["无法直接转换预训练MHA模型"]
    C --> C1["预填充阶段本身可并行化"]
    D --> D1["FlashAttention/PagedAttention<br/>解决部分问题"]
    E --> E1["编码器架构MQA价值有限"]
    
    style B fill:#fff3e0
    style C fill:#fff3e0
    style D fill:#fff3e0
    style E fill:#fff3e0

首先,MQA需要从头训练或进行大量的微调。你无法直接将一个预训练好的MHA模型转换为MQA模型而不损失性能。虽然GQA论文提出了"uptraining"方案——从MHA检查点开始,仅用原训练计算量的5%就能转换为GQA——但这仍然是一个不可忽视的成本。

其次,MQA的收益主要体现在解码阶段。对于预填充(prefill)阶段——处理提示词的初始计算——MQA的优势并不明显,因为预填充本身就可以并行化。

第三,随着其他推理优化技术的发展,MQA的相对优势可能会被削弱。比如FlashAttention通过优化内存访问模式,已经显著提升了注意力计算的效率;PagedAttention通过分页管理KV缓存,减少了内存碎片。这些技术与MQA可以叠加使用,但它们也解决了一部分MQA原本要解决的问题。

最后,MQA更适合解码器架构。对于编码器-解码器架构或纯编码器架构(如BERT),MQA的价值有限,因为这些架构本身就可以充分利用并行化。

从MQA到MLA:效率优化的持续演进

MQA的故事并没有结束。2024年,DeepSeek团队提出了多头潜在注意力(Multi-Head Latent Attention,MLA),这是注意力效率优化的新里程碑。

graph LR
    A["KV优化演进"] --> B["MQA: 硬压缩<br/>减少KV头数"]
    A --> C["GQA: 折中方案<br/>分组共享KV"]
    A --> D["MLA: 软压缩<br/>低秩投影"]
    
    B --> E["效率提升<br/>质量下降 ~2%"]
    C --> F["效率提升<br/>质量接近MHA"]
    D --> G["效率提升 93%<br/>质量超越MHA"]
    
    style G fill:#c8e6c9

MLA的思路与MQA/GQA不同:它不是通过减少KV头的数量来降低内存占用,而是通过低秩压缩来实现。MLA将Key和Value投影到一个低维的潜在空间进行存储,在计算时再解压缩。

这种方法的精妙之处在于:它不仅减少了KV缓存的大小(在DeepSeek-V2中减少了约93%),而且据报告在模型质量上甚至超过了原始MHA。这是一个令人惊讶的结果——通过更高效的表示方式,我们不仅没有损失能力,反而获得了提升。

这揭示了一个深层的洞察:KV缓存的大小并非必须与模型的表示能力成正比。MQA/GQA是通过"硬"压缩(直接减少头数)来换取效率,而MLA则是通过"软"压缩(低维投影)来实现效率提升。

如何选择:MHA vs MQA vs GQA

在实际项目中,如何选择注意力架构?以下是一个简化的决策框架:

graph TD
    A["注意力架构选择"] --> B{"模型质量是否<br/>为最优先目标?"}
    
    B -->|是| C["选择 MHA"]
    B -->|否| D{"内存/速度是否<br/>为关键瓶颈?"}
    
    D -->|否| C
    D -->|是| E{"能否接受<br/>~2%质量下降?"}
    
    E -->|能| F["选择 MQA"]
    E -->|不能| G["选择 GQA"]
    
    H["从现有MHA迁移"] --> I["选择 GQA<br/>通过uptraining"]
    
    style C fill:#e3f2fd
    style F fill:#c8e6c9
    style G fill:#fff3e0

选择MHA的场景:

  • 模型质量是最优先的目标
  • 推理速度和内存占用不是瓶颈
  • 训练资源充足,可以进行充分的超参数调优

选择MQA的场景:

  • 推理延迟是核心指标
  • 内存资源受限(如在边缘设备或小显存GPU上部署)
  • 可以接受约2%的质量下降
  • 有足够的训练资源从头训练模型

选择GQA的场景:

  • 需要在质量和效率之间取得平衡
  • 希望从现有的MHA模型迁移(通过uptraining)
  • 希望灵活调整组数以适应不同的硬件约束

大多数现代大模型(如Llama 2/3、Mistral等)选择了GQA,这反映了一个行业共识:GQA提供了最佳的权衡空间。

总结

MQA的故事是深度学习领域工程创新的一个缩影。

它源于一个简单但深刻的观察:Transformer推理的瓶颈不是计算,而是内存带宽。这个观察催生了一个简单的解决方案:减少需要加载的数据量。

但简单并不意味着浅显。MQA背后的roofline模型、算术强度、数据重用等概念,揭示了一个更普遍的工程原则:在异构计算系统中,理解硬件特性并据此优化算法,往往比纯粹的算法创新更重要。

从MHA到MQA,再到GQA和MLA,我们看到的是一条清晰的演进路径:研究者们在不断地寻找效率与能力的最佳平衡点。这个平衡点可能因应用场景而异,也可能因硬件发展而变,但优化的方向始终明确——让计算单元不等待数据,让内存带宽不成为瓶颈。

当你在部署一个大模型服务时,推理速度慢可能不是模型"太大”,而是KV缓存"太重"。这时候,MQA或GQA可能就是答案——用一点点质量,换取数倍的效率。


参考资料

  1. Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint arXiv:1911.02150.

  2. Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv preprint arXiv:2305.13245. EMNLP 2023.

  3. Vaswani, A., et al. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems.

  4. Almazrouei, E., et al. (2023). The Falcon Series of Open Language Models. arXiv preprint arXiv:2311.16867.

  5. Williams, S., et al. (2009). Roofline: An Insightful Visual Performance Model for Multicore Architectures. Communications of the ACM.

  6. NVIDIA. A100 Tensor Core GPU Architecture Whitepaper.

  7. NVIDIA. H100 Tensor Core GPU Architecture Whitepaper.

  8. Fireworks AI. Multi-Query Attention is All You Need. Technical Blog, 2023.

  9. Kwon, W., et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023.

  10. Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

  11. Liu, A., et al. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv preprint arXiv:2405.04434.

  12. Ge, S., et al. (2023). Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs. ICLR 2024.