2019年,Noam Shazeer在Google发表了一篇仅4页的论文,标题是《Fast Transformer Decoding: One Write-Head is All You Need》。这篇论文提出了一个激进的方案:让所有注意力头共享同一组Key和Value。结果是KV Cache缩小了几十倍,推理速度提升了10倍以上。

但这个方案有一个致命缺陷——模型质量下降明显,训练也不稳定。

四年后,Google团队在2023年提出了GQA(Grouped-Query Attention),找到了MHA和MQA之间的黄金平衡点。这个技术被Llama 2、Llama 3、Mistral等几乎所有现代开源大模型采用。它究竟解决了什么问题?为什么能让大模型推理既快又准?

KV Cache:大模型推理的隐形内存杀手

理解GQA的价值,首先要理解Transformer推理时的内存瓶颈。

在自回归生成中,每生成一个新token,模型都需要访问之前所有token的Key和Value向量。为了避免重复计算,这些向量被缓存起来,这就是KV Cache。

以Llama 2 70B为例:

  • 隐藏层维度:$d = 8192$
  • 注意力头数:$n_h = 64$
  • 层数:$L = 80$
  • KV头数:$n_{kv} = 8$(GQA配置)

对于单个token,每层的KV Cache大小为:

$$\text{KV Cache per layer} = 2 \times n_{kv} \times d_h \times \text{bytes}$$

其中 $d_h = d / n_h = 128$ 是每个头的维度,系数2表示Key和Value各一份。使用FP16(2字节),每层每个token的KV Cache为:

$$2 \times 8 \times 128 \times 2 = 4\text{KB}$$

对于80层模型,每个token总共需要 $4\text{KB} \times 80 = 320\text{KB}$。

这看起来不大,但考虑实际场景:序列长度4096,批次大小8:

$$320\text{KB} \times 4096 \times 8 = 10.5\text{GB}$$

仅仅KV Cache就占用了10GB以上的显存。如果使用标准的MHA(64个KV头),这个数字会膨胀到:

$$320\text{KB} \times \frac{64}{8} = 2.56\text{GB} \times 4096 \times 8 \approx 84\text{GB}$$

这就是为什么Llama 2 70B必须使用GQA——标准的MHA根本无法在单张80GB A100上运行长序列推理。

MQA:极端压缩的代价

2019年,Shazeer提出MQA时,思路非常直接:既然KV Cache的瓶颈在于头数太多,为什么不把所有头合并成一个?

标准MHA中,每个注意力头有独立的Key、Value投影:

$$\mathbf{k}_i^{(s)} = \mathbf{x}_i \mathbf{W}_k^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \mathbf{W}_v^{(s)}$$

其中 $s = 1, \ldots, n_h$ 表示不同的头。

MQA让所有头共享同一组Key和Value:

$$\mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k, \quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v$$

这样,KV Cache从头数 $n_h$ 降低到1,内存占用减少 $n_h$ 倍。

Google的实验数据令人印象深刻:MQA可以实现10-100倍更小的KV Cache,推理速度提升12倍。PaLM模型率先采用了这一技术。

但MQA的问题很快暴露:

训练不稳定:Shazeer在原始论文中就指出,MQA在长输入任务上容易发散。GQA论文进一步证实,从头训练MQA模型在长输入任务上会出现频繁的loss spike。

质量下降:当模型规模增大时,MQA与MHA的性能差距变得更加明显。这是因为所有头被迫使用相同的Key和Value,表达能力受限。

极端压缩的后遗症:对于7B模型,从32个KV头压缩到1个,信息损失约97%。更大的模型头数更多,损失比例更高。

这正是为什么直到2023年,大多数开源模型仍然使用标准MHA——MQA的代价太大了。

GQA:在效率与质量间寻找平衡

GQA的核心思想非常简单:不让所有头共享一组KV,而是分成若干组,每组共享一组。

假设模型有32个注意力头,GQA-8将它们分成8组,每组4个头共享一组Key和Value:

$$\text{Group } g: \{\text{heads } 4g, 4g+1, 4g+2, 4g+3\} \rightarrow \mathbf{k}_i^{(g)}, \mathbf{v}_i^{(g)}$$

这样,KV Cache从头数32降低到8,节省了75%的内存,同时保留了8倍于MQA的表达能力。

MHA vs MQA vs GQA架构对比

图片来源: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

上图清晰展示了三种注意力机制的区别:

  • MHA:每个Query头有独立的Key和Value头
  • MQA:所有Query头共享单个Key和Value头
  • GQA:Query头分组,每组共享一个Key和Value头

GQA的数学表示

设Query头数为 $n_h$,KV头数为 $n_{kv}$(组数),每个KV头服务的Query头数为:

$$n_{rep} = \frac{n_h}{n_{kv}}$$

在计算注意力时,Key和Value需要沿头维度复制 $n_{rep}$ 次:

def repeat_kv(x, n_rep):
    batch, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(batch, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch, seq_len, n_kv_heads * n_rep, head_dim)
    )

注意,这个复制操作只发生在注意力计算阶段,缓存的仍然是原始的 $n_{kv}$ 组KV,因此内存占用没有增加。

GQA论文的核心实验

Google团队用T5架构验证了GQA的效果:

模型 注意力类型 推理时间 平均性能
T5-Large MHA 0.37ms 46.0
T5-XXL MHA 1.51ms 47.2
T5-XXL MQA 0.24ms 46.6
T5-XXL GQA-8 0.28ms 47.1

关键发现:

  • GQA-8的推理速度与MQA几乎相同(0.28ms vs 0.24ms)
  • GQA-8的性能与MHA-XXL几乎相同(47.1 vs 47.2)
  • 相比MHA-Large,GQA-XXL既更快又更好

论文还发现了一个有趣的现象:从头训练GQA比从头训练MQA稳定得多。MQA在长输入任务上容易出现训练不稳定,而GQA没有这个问题。

Uptraining:让现有模型快速迁移

GQA论文的另一个贡献是提出了"uptraining"方案:将现有的MHA模型转换为GQA模型,只需要5%的原始训练计算量。

转换过程非常简单:

  1. Mean Pooling:将每组的Key和Value头取平均
  2. 继续预训练:用原始训练配方继续训练5%的步数

实验表明,Mean Pooling比选择单个头或随机初始化效果更好——它最大程度地保留了预训练模型的知识。

主流模型的GQA配置

GQA已经成为现代大模型的标配。以下是几个主流模型的配置:

Llama 2系列

模型 参数量 Query头数 KV头数 压缩比
Llama-2-7B 7B 32 32 1×(MHA)
Llama-2-13B 13B 40 40 1×(MHA)
Llama-2-70B 70B 64 8

有趣的是,Llama 2的7B和13B模型使用标准MHA,只有70B模型使用了GQA。这反映了大模型更严重的内存压力——70B参数模型的权重已经占用了140GB显存(FP16),必须通过GQA节省KV Cache空间。

Llama 3系列

模型 参数量 Query头数 KV头数 压缩比
Llama-3-8B 8B 32 8
Llama-3-70B 70B 64 8

Llama 3全面采用了GQA,即使是8B模型也使用了8个KV头。这表明Meta意识到了GQA对推理效率的重要性。

Mistral系列

Mistral 7B是第一个在7B规模上采用GQA的开源模型:

  • Query头数:32
  • KV头数:8
  • 压缩比:4×

Mistral团队报告称,GQA使其在保持与Llama 2 13B相当性能的同时,推理速度大幅提升。结合滑动窗口注意力(Sliding Window Attention),Mistral 7B实现了出色的效率-性能平衡。

KV Cache内存计算实战

让我们计算几个实际场景下的KV Cache大小。

通用公式

$$\text{KV Cache (bytes)} = 2 \times L \times n_{kv} \times d_h \times \text{seq\_len} \times \text{batch\_size} \times \text{element\_size}$$

其中:

  • $L$:层数
  • $n_{kv}$:KV头数
  • $d_h$:每头维度($= d / n_h$)
  • $element\_size$:数据类型字节数(FP16为2)

具体计算

场景1:Llama 2 70B,序列长度4096,批次大小1

  • $L = 80, n_{kv} = 8, d_h = 128, \text{element\_size} = 2$
  • $\text{KV Cache} = 2 \times 80 \times 8 \times 128 \times 4096 \times 1 \times 2 = 1.31\text{GB}$

场景2:Llama 2 70B使用MHA(假设),相同配置

  • $n_{kv} = 64$
  • $\text{KV Cache} = 1.31\text{GB} \times 8 = 10.5\text{GB}$

GQA节省了约9.2GB显存——这在80GB显存的A100上是决定性的。

场景3:长上下文场景,序列长度32768

  • GQA:$1.31\text{GB} \times 8 = 10.5\text{GB}$
  • MHA:$10.5\text{GB} \times 8 = 84\text{GB}$

在长上下文场景下,GQA的优势更加明显。这也是为什么支持128K上下文的模型(如Llama 3.1)必须使用GQA。

MLA:DeepSeek的低秩压缩方案

如果说GQA是在"多少个头"上做文章,DeepSeek提出的MLA(Multi-head Latent Attention)则是在"如何存储"上进行了根本性创新。

MLA的核心思想

MLA不再存储完整的Key和Value矩阵,而是:

  1. 压缩:将输入投影到低维潜在空间
  2. 缓存:只存储压缩后的潜在向量
  3. 解压:在计算时动态重建Key和Value

数学表示:

$$\mathbf{c}_t^{KV} = \mathbf{h}_t \mathbf{W}^{DKV} \quad \text{(压缩)}$$

$$\mathbf{k}_t^{(s)} = \mathbf{c}_t^{KV} \mathbf{W}_{UK}^{(s)}, \quad \mathbf{v}_t^{(s)} = \mathbf{c}_t^{KV} \mathbf{W}_{UV}^{(s)} \quad \text{(解压)}$$

其中 $\mathbf{W}^{DKV} \in \mathbb{R}^{d \times d_c}$,$d_c \ll n_h \times d_h$。

MLA与GQA/MQA的对比

方法 KV Cache存储 建模容量
MHA $2 \times L \times n_h \times d_h$
MQA $2 \times L \times d_h$
GQA $2 \times L \times n_g \times d_h$ 中高
MLA $L \times d_c$

DeepSeek-V2论文报告,MLA相比MHA减少了93.3%的KV Cache,同时建模容量甚至超过了MHA。这是因为低秩压缩并没有完全丢弃信息,而是以更紧凑的形式保留。

MLA与其他注意力机制的KV Cache对比

图片来源: DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model

上表展示了四种注意力机制的KV Cache大小和建模容量对比。值得注意的是,MLA在KV Cache最小的同时,建模容量反而最高——这打破了"压缩必然损失性能"的直觉。

MLA的RoPE处理

MLA面临一个技术挑战:如何与RoPE(旋转位置编码)兼容?

RoPE需要对Query和Key应用位置相关的旋转矩阵 $\mathcal{R}_i$。如果直接对潜在向量应用旋转,会破坏矩阵吸收的优化。

DeepSeek的解决方案是解耦RoPE:将Key分成两部分:

$$\mathbf{k}_t^{(s)} = \underbrace{(\mathbf{c}_t^{KV} \mathbf{W}_{UK}^{(s)})}_{\text{压缩部分}} \oplus \underbrace{(\mathbf{x}_t \mathbf{W}_{kr} \mathcal{R}_t)}_{\text{位置部分}}$$

压缩部分可以享受低秩存储的优势,位置部分则独立处理以保留相对位置信息。

实践中的权衡与选择

GQA组数的选择

GQA-1等同于MQA,GQA-$n_h$等同于MHA。如何选择合适的组数?

经验法则

  • 小于7B的模型:可以使用MHA(内存压力不大)
  • 7B-30B的模型:GQA-4到GQA-8
  • 30B以上的模型:GQA-8或更高压缩比

Llama 2的选择印证了这一点:7B和13B使用MHA,70B使用GQA-8。

GQA与量化的协同效应

GQA减少的是KV Cache的头数,而量化减少的是每个元素的存储大小。两者可以叠加:

  • GQA-8:节省87.5%
  • INT8量化:再节省50%
  • 总节省:约94%

这也是为什么Llama 2 70B即使使用GQA,在长上下文场景下仍然需要INT8或INT4量化才能在消费级GPU上运行。

推理框架的支持

主流推理框架对GQA的支持情况:

  • vLLM:原生支持,与PagedAttention无缝集成
  • TensorRT-LLM:原生支持,针对NVIDIA GPU优化
  • llama.cpp:支持,但需要正确配置n_kv_heads参数

需要注意,GQA模型的权重文件中,Key和Value投影矩阵的形状是 $(d, n_{kv} \times d_h)$,而不是 $(d, n_h \times d_h)$。加载模型时必须正确识别这一参数。

从架构创新到工程实践

GQA的成功揭示了一个深层次的规律:在大模型时代,架构设计不仅要考虑计算效率,更要考虑内存效率

传统观点认为,模型越大应该有越多的注意力头来提升表达能力。但KV Cache的线性增长使得这个假设在推理阶段失效——更多的头意味着更多的内存带宽压力。

GQA通过"分组共享"的方式,在表达能力和内存效率之间找到了最优平衡。它证明了一个反直觉的结论:让多个头共享相同的Key和Value,并不会显著损失性能。

这个发现对模型设计有深远影响:

训练阶段的冗余:如果推理时可以让多个头共享KV而不损失性能,那么训练时这些头实际上在学什么?是否可以通过更高效的训练范式直接学习共享表示?

头数的意义:传统上,头数被视为模型容量的重要指标。GQA表明,头数的边际效益在达到一定阈值后急剧下降——至少从KV Cache的角度看是这样。

未来方向:MLA的提出表明,GQA可能不是终点。通过低秩压缩和潜在表示,我们可能找到更激进的压缩方案,同时保持甚至提升建模能力。


参考文献

  1. Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.

  2. Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245.

  3. DeepSeek-AI. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434.

  4. Touvron, H., et al. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288.

  5. Jiang, A. Q., et al. (2023). Mistral 7B. mistral.ai/news/announcing-mistral-7b.

  6. Pope, R., et al. (2022). Efficiently Scaling Transformer Inference. arXiv:2211.05102.

  7. Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.