把一篇2000字的文章喂给大模型,它能在毫秒级返回摘要。但当你把一篇20000字的长文档扔进去,等待第一个输出的时间可能从几百毫秒延长到几秒甚至更久。更诡异的是,生成后续内容的速度却没有明显下降。
这不是服务器的网络波动,也不是模型突然"变笨"了。这是Transformer架构自诞生之日起就携带的基因——注意力机制的二次方复杂度。
2017年,当Google的研究团队在论文《Attention Is All You Need》中提出Transformer时,他们彻底改变了自然语言处理的范式。自注意力机制让模型能够捕捉序列中任意两个位置之间的依赖关系,无论它们相距多远。这个"超能力"的代价是:计算复杂度随序列长度呈平方级增长。
当你把输入从100个token扩展到1000个token,理论上注意力计算量会增加100倍。这意味着如果你的显存够大,模型仍然可以运行,但你可能要等待更长的时间才能看到第一个输出。
注意力机制的数学本质:为什么是O(n²)
理解这个问题的起点是自注意力机制的数学定义。给定一个输入序列,Transformer首先将其转换为三个矩阵:查询(Query)、键(Key)和值(Value)。对于序列中的每一个位置,模型需要计算它与所有其他位置的相关性。
假设序列长度为 $n$,每个token的隐藏维度为 $d$。注意力计算的公式是:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$这里的关键在于 $QK^T$ 这一步。查询矩阵 $Q$ 的形状是 $(n, d)$,键矩阵 $K$ 的形状也是 $(n, d)$。它们的乘积 $QK^T$ 会产生一个 $(n, n)$ 的注意力分数矩阵。
flowchart LR
subgraph 输入序列
T1["Token 1"]
T2["Token 2"]
T3["Token 3"]
TN["Token n"]
end
subgraph 矩阵变换
Q["Q矩阵<br/>(n × d)"]
K["K矩阵<br/>(n × d)"]
V["V矩阵<br/>(n × d)"]
end
subgraph 注意力计算
QK["QK^T<br/>(n × n)<br/>注意力分数矩阵"]
SM["Softmax<br/>(n × n)<br/>归一化权重"]
OUT["输出<br/>(n × d)"]
end
T1 & T2 & T3 & TN --> Q
T1 & T2 & T3 & TN --> K
T1 & T2 & T3 & TN --> V
Q --> QK
K --> QK
QK --> SM
SM --> OUT
V --> OUT
style QK fill:#ffcccc
style SM fill:#ffcccc
图中红色标注的 $(n, n)$ 矩阵就是复杂度瓶颈的根源。当序列长度从100增长到1000,这个矩阵的大小从10,000个元素增长到1,000,000个元素——100倍的增长。当序列长度从1000增长到10000(大约15000字的中文),矩阵大小会达到1亿元素。
更具体地说,注意力计算的总复杂度是 $O(n^2 \cdot d + n \cdot d^2)$。前一项来自计算注意力分数($n^2$ 次点积,每次点积需要 $d$ 次乘加),后一项来自计算最终的输出($n$ 次矩阵-向量乘法,每次涉及 $d^2$ 次运算)。当序列长度 $n$ 远大于隐藏维度 $d$ 时(比如 $n=8192, d=4096$),前一项主导了计算量。
但复杂度只是故事的一半。另一半是内存访问模式,这往往是现代GPU上真正的性能瓶颈。
GPU内存层次与带宽瓶颈
现代GPU的内存架构是一个分层的金字塔。顶层是寄存器(Registers),速度最快但容量最小,每个线程只能访问自己的寄存器。第二层是共享内存(Shared Memory/SRAM),位于芯片内部,延迟极低,容量通常在几百KB到几MB之间。底层是高带宽内存(High Bandwidth Memory, HBM),也就是我们常说的"显存",容量可达几十GB甚至上百GB,但访问延迟是SRAM的数十倍。
以NVIDIA A100为例,它配备了80GB的HBM2e,带宽高达2039 GB/s。听起来很惊人,但这个数字需要放在具体的工作负载下理解。当GPU执行矩阵运算时,它需要先将数据从HBM加载到SRAM,计算完成后再写回HBM。这个数据搬运的过程往往比计算本身更耗时。
graph TD
subgraph "GPU内存金字塔"
REG["寄存器 Registers<br/>━━━━━━━━━━━<br/>容量: ~256KB总容量<br/>延迟: 1周期<br/>带宽: ~数十TB/s"]
SRAM["共享内存 SRAM<br/>━━━━━━━━━━━<br/>容量: 几MB<br/>延迟: ~30周期<br/>带宽: ~数十TB/s"]
L2["L2缓存<br/>━━━━━━━━━━━<br/>容量: 40-60MB<br/>延迟: ~200周期<br/>带宽: ~数TB/s"]
HBM["高带宽内存 HBM<br/>━━━━━━━━━━━<br/>容量: 40-80GB<br/>延迟: ~500周期<br/>带宽: 1-2 TB/s"]
end
REG -->|"最快"| SRAM
SRAM -->|"快"| L2
L2 -->|"较慢"| HBM
style REG fill:#ccffcc
style HBM fill:#ffcccc
算术强度(Arithmetic Intensity) 是理解这个问题的关键指标,定义为每字节内存访问所对应的浮点运算次数:
$$\text{Arithmetic Intensity} = \frac{\text{Total FLOPs}}{\text{Total Bytes Accessed}}$$当算术强度低时,处理器大部分时间在等待数据到达,这种状态被称为内存带宽受限(Memory-Bandwidth Bound)。当算术强度高时,处理器满负荷运转,这种状态被称为计算受限(Compute-Bound)。
graph LR
subgraph "Roofline性能模型"
X[算术强度 FLOPs/Byte]
Y[性能 FLOPs/s]
MB["内存带宽受限区<br/>━━━━━━━━━━━<br/>性能 = 算术强度 × 带宽<br/>瓶颈: HBM带宽"]
CB["计算受限区<br/>━━━━━━━━━━━<br/>性能 = 峰值算力<br/>瓶颈: GPU计算单元"]
X --> MB
X --> CB
end
MB -->|"低算术强度操作<br/>如: 注意力矩阵Softmax"| EXAMPLE1["示例:<br/>Decode阶段<br/>矩阵-向量乘法"]
CB -->|"高算术强度操作<br/>如: 大矩阵乘法"| EXAMPLE2["示例:<br/>Prefill阶段<br/>矩阵-矩阵乘法"]
style MB fill:#ffcccc
style CB fill:#ccffcc
注意力机制恰好是一个低算术强度的操作。问题的核心在于:注意力分数矩阵 $(n, n)$ 需要完整地存储在HBM中,然后被反复读写。在标准实现中,计算Softmax需要对整个注意力矩阵进行全局归约,这导致了大量的HBM访问。
graph TD
subgraph "标准注意力数据流"
A[HBM: Q, K, V矩阵<br/>输入数据]
B[SRAM: 加载Q块、K块]
C[计算: QK^T<br/>注意力分数]
D[HBM: 写回注意力矩阵<br/>n×n元素]
E[SRAM: 重新加载注意力矩阵]
F[计算: Softmax归一化]
G[HBM: 写回归一化结果]
H[SRAM: 加载结果和V矩阵]
I[计算: 加权求和]
J[HBM: 写回最终输出]
end
A -->|加载| B
B -->|计算| C
C -->|写回| D
D -->|重新加载| E
E -->|计算| F
F -->|写回| G
G -->|加载| H
H -->|计算| I
I -->|写回| J
style D fill:#ffcccc
style G fill:#ffcccc
style J fill:#ffcccc
图中红色标注的多次HBM读写是性能瓶颈的根源。当序列长度增加时,这个 $(n, n)$ 矩阵呈平方级增长,内存压力急剧上升。每次写回和重新加载都需要传输 $n^2$ 个元素,带宽消耗惊人。
Prefill与Decode:两种截然不同的阶段
LLM推理过程分为两个阶段,它们在计算特征上有着本质的区别。理解这个区别,是优化推理性能的关键。
flowchart TD
subgraph "LLM推理两阶段"
INPUT[输入序列<br/>"请总结这篇文章..."]
PREFILL["Prefill阶段<br/>━━━━━━━━━━━<br/>• 一次性处理所有输入token<br/>• 并行计算所有KV向量<br/>• 矩阵-矩阵乘法<br/>• 计算受限<br/>• 影响TTFT首token延迟"]
DECODE["Decode阶段<br/>━━━━━━━━━━━<br/>• 逐个生成输出token<br/>• 增量式计算<br/>• 矩阵-向量乘法<br/>• 内存带宽受限<br/>• 影响TPOT每token时间"]
OUTPUT[输出序列<br/>"这篇文章讨论了..."]
end
INPUT --> PREFILL
PREFILL -->|首个输出token| DECODE
DECODE -->|循环生成| DECODE
DECODE --> OUTPUT
style PREFILL fill:#ccffcc
style DECODE fill:#ffcccc
Prefill阶段发生在推理的最开始。模型需要一次性处理整个输入序列,为每个输入token计算Key和Value向量,并将它们存储在KV Cache中。这是一个高度并行的过程——所有输入token可以同时计算。此时的注意力计算是矩阵-矩阵乘法,算术强度高,GPU利用率高,性能接近计算受限状态。
Decode阶段紧随其后。模型开始逐个生成输出token。每生成一个新token,模型需要:
- 用这个新token的Query向量与之前所有token的Key向量计算注意力分数
- 使用注意力分数对所有Value向量进行加权求和
- 将新token的Key和Value向量追加到KV Cache中
这是一个串行的、增量式的过程。每生成一个token,只需要进行一次矩阵-向量乘法(而不是矩阵-矩阵乘法)。算术强度骤降,性能变成了内存带宽受限。
graph LR
subgraph "Prefill阶段"
P1[输入: n个token]
P2[计算: Q,K,V矩阵<br/>n×n注意力]
P3[输出: KV Cache<br/>+ 第一个token]
P1 --> P2 --> P3
end
subgraph "Decode阶段 - 每个step"
D1[输入: 1个新token]
D2[计算: 与历史KV的<br/>注意力1×n]
D3[输出: 1个token<br/>+ 更新KV Cache]
D1 --> D2 --> D3
end
P3 --> D1
D3 -->|"循环"| D1
style P2 fill:#ccffcc
style D2 fill:#ffcccc
研究表明,在Prefill阶段,几乎所有的延迟都可以归因于计算;而在Decode阶段,不到20%的延迟来自计算,超过80%的时间花在等待数据传输上。
这解释了开篇提到的现象:长输入主要影响的是Prefill阶段的时长(首token延迟,TTFT),而对Decode阶段的token生成速度(每token生成时间,TPOT)影响较小。因为Prefill需要处理整个输入序列,而Decode只需要处理一个新token。
两个阶段的不同特征也催生了不同的优化策略:
- Prefill优化:减少计算量,如FlashAttention、稀疏注意力
- Decode优化:减少内存访问,如KV Cache压缩、GQA
KV Cache:用空间换时间的经典策略
KV Cache是Transformer推理中最重要的优化之一。它的核心思想很简单:既然Decode阶段每生成一个token都需要用到之前所有token的Key和Value向量,为什么不把它们缓存起来,避免重复计算?
没有KV Cache时,生成第 $t$ 个token需要重新计算前 $t-1$ 个token的Key和Value向量,复杂度是 $O(t^2)$。有了KV Cache,前 $t-1$ 个token的Key和Value已经存储,只需要计算第 $t$ 个token的Key和Value,然后进行一次注意力计算即可。复杂度降到了 $O(t)$。
但这个优化的代价是内存消耗。KV Cache的大小公式是:
$$\text{KV Cache Size} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times s \times \text{sizeof}(\text{dtype})$$其中 $L$ 是层数,$n_{\text{heads}}$ 是注意力头数,$d_{\text{head}}$ 是每个头的维度,$s$ 是序列长度。
graph TD
subgraph "KV Cache内存消耗"
SEQ["序列长度 s"]
LAYER["层数 L"]
HEAD["注意力头数 n_heads"]
DIM["每头维度 d_head"]
DTYPE["数据类型 dtype"]
KVC["KV Cache大小"]
end
SEQ --> KVC
LAYER --> KVC
HEAD --> KVC
DIM --> KVC
DTYPE --> KVC
KVC -->|"公式"| FORMULA["2 × L × n_heads × d_head × s × dtype_size"]
subgraph "示例: Llama-2-70B"
EX["L=80, n_heads=64<br/>d_head=128, FP16<br/>s=4096"]
RESULT["KV Cache = 53.7 GB"]
end
EX --> RESULT
style KVC fill:#ffcccc
以Llama-2-70B为例:$L=80$, $n_{\text{heads}}=64$, $d_{\text{head}}=128$。假设使用FP16(2字节),序列长度为4096:
$$\text{KV Cache Size} = 2 \times 80 \times 64 \times 128 \times 4096 \times 2 = 53.7 \text{ GB}$$这个数字对于单张GPU来说相当可观。当序列长度扩展到32K时,KV Cache就需要超过400GB——超出了几乎所有消费级GPU的内存容量。
KV Cache带来的另一个问题是内存碎片。当多个请求并发处理时,每个请求的序列长度各不相同,而且会动态增长。传统实现会为每个请求预分配最大序列长度的连续内存,这导致了巨大的浪费——研究表明传统系统中只有20-38%的KV Cache内存被实际使用。
graph TD
subgraph "传统KV Cache分配"
REQ1["请求1<br/>实际长度: 500<br/>预分配: 4096"]
REQ2["请求2<br/>实际长度: 1000<br/>预分配: 4096"]
REQ3["请求3<br/>实际长度: 200<br/>预分配: 4096"]
WASTE1["浪费: 3596块"]
WASTE2["浪费: 3096块"]
WASTE3["浪费: 3896块"]
REQ1 --> WASTE1
REQ2 --> WASTE2
REQ3 --> WASTE3
end
subgraph "PagedAttention"
BLOCKS["固定大小块池"]
BT1["块表1<br/>映射: 逻辑→物理"]
BT2["块表2"]
BT3["块表3"]
UTIL["利用率: ~100%<br/>按需分配,无碎片"]
BLOCKS --> BT1 & BT2 & BT3
BT1 & BT2 & BT3 --> UTIL
end
style WASTE1 fill:#ffcccc
style WASTE2 fill:#ffcccc
style WASTE3 fill:#ffcccc
style UTIL fill:#ccffcc
这就是PagedAttention诞生的背景。借鉴操作系统的虚拟内存管理,PagedAttention将KV Cache分割成固定大小的块(Block),每个块可以存储固定数量的token。这些块不需要连续存储,通过一个块表(Block Table)来映射逻辑地址到物理地址。这大大减少了内存碎片,将KV Cache利用率提升到接近100%。
FlashAttention:重新思考IO感知
如果说KV Cache是"用空间换时间"的典范,那么FlashAttention则是"用计算换IO"的革命。
FlashAttention的核心洞察是:注意力计算的性能瓶颈不在于计算本身,而在于内存访问。与其优化计算,不如优化数据搬运。
标准注意力实现需要多次在HBM和SRAM之间传输数据。对于序列长度 $n$,这需要 $O(n^2)$ 次HBM读写。
FlashAttention采用了分块计算(Tiling) 策略:将Q、K、V矩阵分割成小块,每块大小刚好能放入SRAM。然后,对于每个Q块,遍历所有K块和V块,在SRAM内完成注意力计算,只将最终结果写回HBM。
关键的技术挑战是Softmax计算。标准Softmax需要知道整个注意力分数矩阵才能归一化。FlashAttention采用了在线Softmax(Online Softmax) 技术,通过巧妙的数学变换,使得可以逐块计算Softmax,而无需一次性访问整个矩阵。
graph LR
subgraph "FlashAttention分块策略"
QBLOCKS["Q矩阵<br/>分成小块"]
KBLOCKS["K矩阵<br/>分成小块"]
VBLOCKS["V矩阵<br/>分成小块"]
SRAM["SRAM<br/>计算区域"]
RESULT["最终结果"]
QBLOCKS -->|"加载Q块"| SRAM
KBLOCKS -->|"循环加载K块"| SRAM
VBLOCKS -->|"循环加载V块"| SRAM
SRAM -->|"完整计算<br/>QK^T+Softmax+×V"| RESULT
end
style SRAM fill:#ccffcc
FlashAttention将注意力计算的HBM访问次数从 $O(n^2)$ 降到了 $O(n)$,实现了显著的加速。FlashAttention-2进一步优化了并行策略,在A100上相比标准实现提速可达2-4倍。
graph TD
subgraph "复杂度对比"
STD["标准注意力<br/>━━━━━━━━━━━<br/>HBM访问: O n²<br/>内存: O n²"]
FLASH["FlashAttention<br/>━━━━━━━━━━━<br/>HBM访问: O n<br/>内存: O n"]
end
subgraph "性能提升"
SEQ["序列长度"]
SPEED["加速比"]
S1["n=1K: 1.5-2x"]
S2["n=4K: 2-3x"]
S3["n=16K: 3-4x"]
end
STD --> FLASH
SEQ --> S1 & S2 & S3
S1 & S2 & S3 --> SPEED
style STD fill:#ffcccc
style FLASH fill:#ccffcc
更多优化策略:从架构到系统
除了FlashAttention和KV Cache,研究者们还提出了多种优化策略来缓解序列长度带来的性能问题。
分组查询注意力(Grouped Query Attention, GQA) 是一种架构级别的优化。在标准的多头注意力(MHA)中,每个注意力头都有独立的Key和Value向量。GQA将多个Query头共享同一组Key和Value头,从而减少KV Cache的大小。
graph TD
subgraph "MHA: 多头注意力"
Q1["Q头1"]
Q2["Q头2"]
Q3["Q头3"]
Q4["Q头4"]
K1["K头1"]
K2["K头2"]
K3["K头3"]
K4["K头4"]
V1["V头1"]
V2["V头2"]
V3["V头3"]
V4["V头4"]
Q1 --> K1 --> V1
Q2 --> K2 --> V2
Q3 --> K3 --> V3
Q4 --> K4 --> V4
end
subgraph "GQA: 分组查询注意力"
GQ1["Q头1"]
GQ2["Q头2"]
GQ3["Q头3"]
GQ4["Q头4"]
GK["共享K头"]
GV["共享V头"]
GQ1 & GQ2 --> GK --> GV
GQ3 & GQ4 --> GK --> GV
end
subgraph "MQA: 多查询注意力"
MQ1["Q头1"]
MQ2["Q头2"]
MQ3["Q头3"]
MQ4["Q头4"]
MK["单个K头"]
MV["单个V头"]
MQ1 & MQ2 & MQ3 & MQ4 --> MK --> MV
end
style K1 fill:#ffcccc
style K2 fill:#ffcccc
style K3 fill:#ffcccc
style K4 fill:#ffcccc
style GK fill:#ffffcc
style MK fill:#ccffcc
Llama-2-70B使用GQA-8,即每8个Query头共享1组Key和Value头,KV Cache大小减少了87.5%。GQA是多头注意力(MHA)和多查询注意力(MQA)之间的折中。MQA极端地让所有Query头共享一组Key和Value,KV Cache最小但性能损失较大。GQA在内存效率和模型质量之间取得了更好的平衡。
连续批处理(Continuous Batching) 是系统层面的优化。传统批处理需要等待批次中所有请求完成才能开始新批次,导致GPU资源浪费。连续批处理采用迭代级调度:每当一个请求生成完毕,立即从队列中取出新请求填补空位。
gantt
title 传统批处理 vs 连续批处理
dateFormat X
axisFormat %s
section 传统批处理
请求A Prefill :a1, 0, 2
请求A Decode :a2, after a1, 4
请求B Prefill :b1, 0, 2
请求B Decode :b2, after b1, 6
等待A完成 :crit, wait1, after a2, 2
section 连续批处理
请求A Prefill :ca1, 0, 2
请求A Decode :ca2, after ca1, 4
请求B Prefill :cb1, after ca1, 2
请求B Decode :cb2, after cb1, 6
请求C Prefill :cc1, after ca2, 2
请求C Decode :cc2, after cc1, 3
这可以显著提升吞吐量,实验显示可达23倍的提升。
Ring Attention 是面向分布式场景的解决方案。当序列长度超过单张GPU的内存容量时,可以将序列分割到多个GPU上,通过环形通信模式传递Key和Value块。这使得训练和推理百万级token的超长序列成为可能。
graph LR
subgraph "Ring Attention环形通信"
GPU1["GPU 1<br/>K1,V1块"]
GPU2["GPU 2<br/>K2,V2块"]
GPU3["GPU 3<br/>K3,V3块"]
GPU4["GPU 4<br/>K4,V4块"]
GPU1 -->|"传递K,V"| GPU2
GPU2 -->|"传递K,V"| GPU3
GPU3 -->|"传递K,V"| GPU4
GPU4 -->|"传递K,V"| GPU1
end
Q["Q矩阵分布在各GPU"]
Q --> GPU1 & GPU2 & GPU3 & GPU4
style GPU1 fill:#ccffcc
style GPU2 fill:#ffffcc
style GPU3 fill:#ffcccc
style GPU4 fill:#ccccff
实践指南:不同场景下的策略选择
面对具体的应用场景,如何选择合适的优化策略?
graph TD
subgraph "场景分类与优化策略"
SHORT["短序列交互<br/>输入 < 4K<br/>━━━━━━━━━━━<br/>• 重点关注KV Cache效率<br/>• 选择GQA模型<br/>• 使用vLLM/TensorRT-LLM"]
MEDIUM["长文档处理<br/>输入 4K-32K<br/>━━━━━━━━━━━<br/>• Prefill是主要瓶颈<br/>• FlashAttention必须<br/>• 关注显存容量"]
LONG["超长上下文<br/>输入 > 32K<br/>━━━━━━━━━━━<br/>• 单GPU可能无法处理<br/>• 考虑Ring Attention<br/>• 分层处理策略"]
CONCURRENT["高并发服务<br/>━━━━━━━━━━━<br/>• 重点关注吞吐量<br/>• 连续批处理核心<br/>• Prefill/Decode混合执行"]
end
style SHORT fill:#ccffcc
style MEDIUM fill:#ffffcc
style LONG fill:#ffcccc
style CONCURRENT fill:#ccccff
从性能指标的角度,需要区分优化目标:
- TTFT(首Token延迟):主要受Prefill影响,优化重点是FlashAttention、并行计算
- TPOT(每Token生成时间):主要受Decode影响,优化重点是KV Cache效率、GQA
- 吞吐量:受批处理效率影响,优化重点是连续批处理、内存管理
结语
序列长度对Transformer推理的影响远不止"变慢"这么简单。它触及了注意力机制的核心设计:通过 $O(n^2)$ 的计算复杂度来捕捉序列中的全局依赖关系。这个设计在处理短序列时优雅高效,但在面对长文本时暴露了GPU内存带宽的瓶颈。
从KV Cache的"空间换时间",到FlashAttention的"计算换IO",再到GQA的架构改进和Continuous Batching的系统优化,研究社区在过去几年中构建了一套完整的工具链来应对这个挑战。这些优化不是孤立的,它们相互配合,共同构成了现代LLM推理引擎的基础。
当你下次面对一个长文档需要等待几秒钟才能看到第一个输出时,你会知道这段时间里GPU正在做什么:它在HBM和SRAM之间搬运着数以亿计的注意力分数,在计算与内存带宽的夹缝中,努力让这个对全局依赖关系的美好愿景成为现实。
参考文献
- Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.
- Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
- Kwon, W., et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023.
- Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245.
- Liu, H., et al. (2024). Ring Attention with Blockwise Transformers for Near-Infinite Context. ICLR 2024.
- Yu, G.I., et al. (2022). Orca: A Distributed Serving System for Transformer-Based Generative Models. OSDI 2022.
- NVIDIA. (2023). Mastering LLM Techniques: Inference Optimization. NVIDIA Developer Blog.
- BentoML. (2024). Key Metrics for LLM Inference. LLM Inference Handbook.
- Databricks. (2023). LLM Inference Performance Engineering: Best Practices.