2019年,Noam Shazeer在Google发表了一篇仅4页的论文,标题是《Fast Transformer Decoding: One Write-Head is All You Need》。这篇论文提出了一个激进的方案:让所有注意力头共享同一组Key和Value。结果是KV Cache缩小了几十倍,推理速度提升了10倍以上。
但这个方案有一个致命缺陷——模型质量下降明显,训练也不稳定。
四年后,Google团队在2023年提出了GQA(Grouped-Query Attention),找到了MHA和MQA之间的黄金平衡点。这个技术被Llama 2、Llama 3、Mistral等几乎所有现代开源大模型采用。它究竟解决了什么问题?为什么能让大模型推理既快又准?
KV Cache:大模型推理的隐形内存杀手
理解GQA的价值,首先要理解Transformer推理时的内存瓶颈。
在自回归生成中,每生成一个新token,模型都需要访问之前所有token的Key和Value向量。为了避免重复计算,这些向量被缓存起来,这就是KV Cache。
以Llama 2 70B为例:
- 隐藏层维度:$d = 8192$
- 注意力头数:$n_h = 64$
- 层数:$L = 80$
- KV头数:$n_{kv} = 8$(GQA配置)
对于单个token,每层的KV Cache大小为:
$$\text{KV Cache per layer} = 2 \times n_{kv} \times d_h \times \text{bytes}$$其中 $d_h = d / n_h = 128$ 是每个头的维度,系数2表示Key和Value各一份。使用FP16(2字节),每层每个token的KV Cache为:
$$2 \times 8 \times 128 \times 2 = 4\text{KB}$$对于80层模型,每个token总共需要 $4\text{KB} \times 80 = 320\text{KB}$。
这看起来不大,但考虑实际场景:序列长度4096,批次大小8:
$$320\text{KB} \times 4096 \times 8 = 10.5\text{GB}$$仅仅KV Cache就占用了10GB以上的显存。如果使用标准的MHA(64个KV头),这个数字会膨胀到:
$$320\text{KB} \times \frac{64}{8} = 2.56\text{GB} \times 4096 \times 8 \approx 84\text{GB}$$这就是为什么Llama 2 70B必须使用GQA——标准的MHA根本无法在单张80GB A100上运行长序列推理。
MQA:极端压缩的代价
2019年,Shazeer提出MQA时,思路非常直接:既然KV Cache的瓶颈在于头数太多,为什么不把所有头合并成一个?
标准MHA中,每个注意力头有独立的Key、Value投影:
$$\mathbf{k}_i^{(s)} = \mathbf{x}_i \mathbf{W}_k^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \mathbf{W}_v^{(s)}$$其中 $s = 1, \ldots, n_h$ 表示不同的头。
MQA让所有头共享同一组Key和Value:
$$\mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k, \quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v$$这样,KV Cache从头数 $n_h$ 降低到1,内存占用减少 $n_h$ 倍。
Google的实验数据令人印象深刻:MQA可以实现10-100倍更小的KV Cache,推理速度提升12倍。PaLM模型率先采用了这一技术。
但MQA的问题很快暴露:
训练不稳定:Shazeer在原始论文中就指出,MQA在长输入任务上容易发散。GQA论文进一步证实,从头训练MQA模型在长输入任务上会出现频繁的loss spike。
质量下降:当模型规模增大时,MQA与MHA的性能差距变得更加明显。这是因为所有头被迫使用相同的Key和Value,表达能力受限。
极端压缩的后遗症:对于7B模型,从32个KV头压缩到1个,信息损失约97%。更大的模型头数更多,损失比例更高。
这正是为什么直到2023年,大多数开源模型仍然使用标准MHA——MQA的代价太大了。
GQA:在效率与质量间寻找平衡
GQA的核心思想非常简单:不让所有头共享一组KV,而是分成若干组,每组共享一组。
假设模型有32个注意力头,GQA-8将它们分成8组,每组4个头共享一组Key和Value:
$$\text{Group } g: \{\text{heads } 4g, 4g+1, 4g+2, 4g+3\} \rightarrow \mathbf{k}_i^{(g)}, \mathbf{v}_i^{(g)}$$这样,KV Cache从头数32降低到8,节省了75%的内存,同时保留了8倍于MQA的表达能力。

图片来源: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
上图清晰展示了三种注意力机制的区别:
- MHA:每个Query头有独立的Key和Value头
- MQA:所有Query头共享单个Key和Value头
- GQA:Query头分组,每组共享一个Key和Value头
GQA的数学表示
设Query头数为 $n_h$,KV头数为 $n_{kv}$(组数),每个KV头服务的Query头数为:
$$n_{rep} = \frac{n_h}{n_{kv}}$$在计算注意力时,Key和Value需要沿头维度复制 $n_{rep}$ 次:
def repeat_kv(x, n_rep):
batch, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(batch, seq_len, n_kv_heads, n_rep, head_dim)
.reshape(batch, seq_len, n_kv_heads * n_rep, head_dim)
)
注意,这个复制操作只发生在注意力计算阶段,缓存的仍然是原始的 $n_{kv}$ 组KV,因此内存占用没有增加。
GQA论文的核心实验
Google团队用T5架构验证了GQA的效果:
| 模型 | 注意力类型 | 推理时间 | 平均性能 |
|---|---|---|---|
| T5-Large | MHA | 0.37ms | 46.0 |
| T5-XXL | MHA | 1.51ms | 47.2 |
| T5-XXL | MQA | 0.24ms | 46.6 |
| T5-XXL | GQA-8 | 0.28ms | 47.1 |
关键发现:
- GQA-8的推理速度与MQA几乎相同(0.28ms vs 0.24ms)
- GQA-8的性能与MHA-XXL几乎相同(47.1 vs 47.2)
- 相比MHA-Large,GQA-XXL既更快又更好
论文还发现了一个有趣的现象:从头训练GQA比从头训练MQA稳定得多。MQA在长输入任务上容易出现训练不稳定,而GQA没有这个问题。
Uptraining:让现有模型快速迁移
GQA论文的另一个贡献是提出了"uptraining"方案:将现有的MHA模型转换为GQA模型,只需要5%的原始训练计算量。
转换过程非常简单:
- Mean Pooling:将每组的Key和Value头取平均
- 继续预训练:用原始训练配方继续训练5%的步数
实验表明,Mean Pooling比选择单个头或随机初始化效果更好——它最大程度地保留了预训练模型的知识。
主流模型的GQA配置
GQA已经成为现代大模型的标配。以下是几个主流模型的配置:
Llama 2系列
| 模型 | 参数量 | Query头数 | KV头数 | 压缩比 |
|---|---|---|---|---|
| Llama-2-7B | 7B | 32 | 32 | 1×(MHA) |
| Llama-2-13B | 13B | 40 | 40 | 1×(MHA) |
| Llama-2-70B | 70B | 64 | 8 | 8× |
有趣的是,Llama 2的7B和13B模型使用标准MHA,只有70B模型使用了GQA。这反映了大模型更严重的内存压力——70B参数模型的权重已经占用了140GB显存(FP16),必须通过GQA节省KV Cache空间。
Llama 3系列
| 模型 | 参数量 | Query头数 | KV头数 | 压缩比 |
|---|---|---|---|---|
| Llama-3-8B | 8B | 32 | 8 | 4× |
| Llama-3-70B | 70B | 64 | 8 | 8× |
Llama 3全面采用了GQA,即使是8B模型也使用了8个KV头。这表明Meta意识到了GQA对推理效率的重要性。
Mistral系列
Mistral 7B是第一个在7B规模上采用GQA的开源模型:
- Query头数:32
- KV头数:8
- 压缩比:4×
Mistral团队报告称,GQA使其在保持与Llama 2 13B相当性能的同时,推理速度大幅提升。结合滑动窗口注意力(Sliding Window Attention),Mistral 7B实现了出色的效率-性能平衡。
KV Cache内存计算实战
让我们计算几个实际场景下的KV Cache大小。
通用公式
$$\text{KV Cache (bytes)} = 2 \times L \times n_{kv} \times d_h \times \text{seq\_len} \times \text{batch\_size} \times \text{element\_size}$$其中:
- $L$:层数
- $n_{kv}$:KV头数
- $d_h$:每头维度($= d / n_h$)
- $element\_size$:数据类型字节数(FP16为2)
具体计算
场景1:Llama 2 70B,序列长度4096,批次大小1
- $L = 80, n_{kv} = 8, d_h = 128, \text{element\_size} = 2$
- $\text{KV Cache} = 2 \times 80 \times 8 \times 128 \times 4096 \times 1 \times 2 = 1.31\text{GB}$
场景2:Llama 2 70B使用MHA(假设),相同配置
- $n_{kv} = 64$
- $\text{KV Cache} = 1.31\text{GB} \times 8 = 10.5\text{GB}$
GQA节省了约9.2GB显存——这在80GB显存的A100上是决定性的。
场景3:长上下文场景,序列长度32768
- GQA:$1.31\text{GB} \times 8 = 10.5\text{GB}$
- MHA:$10.5\text{GB} \times 8 = 84\text{GB}$
在长上下文场景下,GQA的优势更加明显。这也是为什么支持128K上下文的模型(如Llama 3.1)必须使用GQA。
MLA:DeepSeek的低秩压缩方案
如果说GQA是在"多少个头"上做文章,DeepSeek提出的MLA(Multi-head Latent Attention)则是在"如何存储"上进行了根本性创新。
MLA的核心思想
MLA不再存储完整的Key和Value矩阵,而是:
- 压缩:将输入投影到低维潜在空间
- 缓存:只存储压缩后的潜在向量
- 解压:在计算时动态重建Key和Value
数学表示:
$$\mathbf{c}_t^{KV} = \mathbf{h}_t \mathbf{W}^{DKV} \quad \text{(压缩)}$$$$\mathbf{k}_t^{(s)} = \mathbf{c}_t^{KV} \mathbf{W}_{UK}^{(s)}, \quad \mathbf{v}_t^{(s)} = \mathbf{c}_t^{KV} \mathbf{W}_{UV}^{(s)} \quad \text{(解压)}$$
其中 $\mathbf{W}^{DKV} \in \mathbb{R}^{d \times d_c}$,$d_c \ll n_h \times d_h$。
MLA与GQA/MQA的对比
| 方法 | KV Cache存储 | 建模容量 |
|---|---|---|
| MHA | $2 \times L \times n_h \times d_h$ | 高 |
| MQA | $2 \times L \times d_h$ | 低 |
| GQA | $2 \times L \times n_g \times d_h$ | 中高 |
| MLA | $L \times d_c$ | 高 |
DeepSeek-V2论文报告,MLA相比MHA减少了93.3%的KV Cache,同时建模容量甚至超过了MHA。这是因为低秩压缩并没有完全丢弃信息,而是以更紧凑的形式保留。

图片来源: DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
上表展示了四种注意力机制的KV Cache大小和建模容量对比。值得注意的是,MLA在KV Cache最小的同时,建模容量反而最高——这打破了"压缩必然损失性能"的直觉。
MLA的RoPE处理
MLA面临一个技术挑战:如何与RoPE(旋转位置编码)兼容?
RoPE需要对Query和Key应用位置相关的旋转矩阵 $\mathcal{R}_i$。如果直接对潜在向量应用旋转,会破坏矩阵吸收的优化。
DeepSeek的解决方案是解耦RoPE:将Key分成两部分:
$$\mathbf{k}_t^{(s)} = \underbrace{(\mathbf{c}_t^{KV} \mathbf{W}_{UK}^{(s)})}_{\text{压缩部分}} \oplus \underbrace{(\mathbf{x}_t \mathbf{W}_{kr} \mathcal{R}_t)}_{\text{位置部分}}$$压缩部分可以享受低秩存储的优势,位置部分则独立处理以保留相对位置信息。
实践中的权衡与选择
GQA组数的选择
GQA-1等同于MQA,GQA-$n_h$等同于MHA。如何选择合适的组数?
经验法则:
- 小于7B的模型:可以使用MHA(内存压力不大)
- 7B-30B的模型:GQA-4到GQA-8
- 30B以上的模型:GQA-8或更高压缩比
Llama 2的选择印证了这一点:7B和13B使用MHA,70B使用GQA-8。
GQA与量化的协同效应
GQA减少的是KV Cache的头数,而量化减少的是每个元素的存储大小。两者可以叠加:
- GQA-8:节省87.5%
- INT8量化:再节省50%
- 总节省:约94%
这也是为什么Llama 2 70B即使使用GQA,在长上下文场景下仍然需要INT8或INT4量化才能在消费级GPU上运行。
推理框架的支持
主流推理框架对GQA的支持情况:
- vLLM:原生支持,与PagedAttention无缝集成
- TensorRT-LLM:原生支持,针对NVIDIA GPU优化
- llama.cpp:支持,但需要正确配置
n_kv_heads参数
需要注意,GQA模型的权重文件中,Key和Value投影矩阵的形状是 $(d, n_{kv} \times d_h)$,而不是 $(d, n_h \times d_h)$。加载模型时必须正确识别这一参数。
从架构创新到工程实践
GQA的成功揭示了一个深层次的规律:在大模型时代,架构设计不仅要考虑计算效率,更要考虑内存效率。
传统观点认为,模型越大应该有越多的注意力头来提升表达能力。但KV Cache的线性增长使得这个假设在推理阶段失效——更多的头意味着更多的内存带宽压力。
GQA通过"分组共享"的方式,在表达能力和内存效率之间找到了最优平衡。它证明了一个反直觉的结论:让多个头共享相同的Key和Value,并不会显著损失性能。
这个发现对模型设计有深远影响:
训练阶段的冗余:如果推理时可以让多个头共享KV而不损失性能,那么训练时这些头实际上在学什么?是否可以通过更高效的训练范式直接学习共享表示?
头数的意义:传统上,头数被视为模型容量的重要指标。GQA表明,头数的边际效益在达到一定阈值后急剧下降——至少从KV Cache的角度看是这样。
未来方向:MLA的提出表明,GQA可能不是终点。通过低秩压缩和潜在表示,我们可能找到更激进的压缩方案,同时保持甚至提升建模能力。
参考文献
-
Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.
-
Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245.
-
DeepSeek-AI. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434.
-
Touvron, H., et al. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288.
-
Jiang, A. Q., et al. (2023). Mistral 7B. mistral.ai/news/announcing-mistral-7b.
-
Pope, R., et al. (2022). Efficiently Scaling Transformer Inference. arXiv:2211.05102.
-
Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.