显存墙:大模型训练的第一道坎
2020年,OpenAI训练GPT-3时,1750亿参数的模型需要超过350GB的显存——这远超任何单张GPU的容量。三年后,Meta训练Llama 2-70B时,单张A100 80GB显卡甚至无法完整加载模型权重。显存,而非计算能力,已经成为大模型训练的首要瓶颈。
这个问题的根源在于一个简单但残酷的数学事实:训练一个参数量为P的模型,所需显存远超P字节。在混合精度训练配合Adam优化器的标准配置下,每个参数需要约20字节的显存——模型权重(FP16)2字节、FP32主副本4字节、梯度(FP16)2字节、优化器动量4字节、优化器方差4字节。这意味着一个70亿参数的模型,仅静态内存就需要140GB,而这还不包括训练过程中产生的大量中间激活值。
pie title 单参数显存消耗分解(混合精度+Adam)
"FP32主副本" : 4
"优化器动量" : 4
"优化器方差" : 4
"FP16权重" : 2
"FP16梯度" : 2
本文将深入解析过去五年间涌现的显存优化技术,从ZeRO的分片策略到Flash Attention的IO感知算法,揭示这些技术如何将万亿参数模型的训练从理论变为现实。
显存消耗的四重奏
要理解显存优化,首先需要解构训练过程中的显存消耗。GPU显存中的数据可以分为两大类:静态的模型状态和动态的中间激活值。
静态内存:模型状态的固定开销
模型状态包括三个核心组件,它们的大小与参数量成正比:
模型权重:在混合精度训练中,模型权重通常以FP16格式存储,每个参数占用2字节。同时,为了数值稳定性,需要维护一个FP32格式的"主副本",额外占用4字节。这构成了6字节的基础开销。
梯度:反向传播过程中计算的梯度同样需要存储。在混合精度训练中,梯度通常以FP16格式存储,占用2字节。某些框架还会维护FP32格式的梯度副本,进一步增加开销。
优化器状态:这是最容易被低估的内存消耗。以Adam优化器为例,它需要为每个参数维护两个状态变量:一阶动量(momentum)和二阶动量(variance),两者都以FP32格式存储,各占4字节。加上FP32参数副本,优化器状态总计12字节每参数。
将上述组件相加,使用Adam优化器和混合精度训练时,每个参数的静态内存消耗为:
- FP16权重:2字节
- FP32主副本:4字节
- FP16梯度:2字节
- FP32动量:4字节
- FP32方差:4字节
- 总计:16字节每参数(若包含FP32梯度副本则为20字节)
对于一个70亿参数的模型,静态内存需求约为112GB,这已经超过了A100 80GB的容量。
动态内存:激活值的增长规律
激活值是前向传播过程中产生的中间结果,需要保存以供反向传播使用。与静态内存不同,激活值的大小与输入数据的维度密切相关。
以Transformer模型为例,假设输入序列长度为S,批次大小为B,隐藏维度为H,注意力头数为N,每个头的维度为D(H = N × D),模型层数为L。激活值内存的主要来源包括:
自注意力层:
- Q、K、V投影结果:各为 B × N × S × D
- 注意力分数矩阵:B × N × S × S(这是显存消耗的关键,与序列长度成平方关系)
- 注意力输出:B × N × S × D
前馈网络层:
- 第一层输出:B × S × 4H(FFN通常将维度扩展4倍)
- 激活函数输出:B × S × 4H
综合推导,单个Transformer层的激活值内存可表示为:
$$M_{activation}^{layer} = B \times N \times S \times (S + 2D) + 10 \times B \times S \times H$$第一项来自注意力机制(与序列长度平方成正比),第二项来自其他组件(与序列长度线性相关)。对于长序列场景,第一项将主导总内存消耗。
整个模型的激活值内存为:
$$M_{activations} = L \times M_{activation}^{layer}$$这个公式揭示了一个关键洞察:激活值内存与批次大小和序列长度成正比,而与参数量间接相关(通过层数和隐藏维度)。这意味着即使模型参数量不变,增加批次大小或序列长度也会显著增加显存需求。
graph TD
A[GPU显存] --> B[静态内存]
A --> C[动态内存]
B --> D[模型权重<br/>2-6字节/参数]
B --> E[梯度<br/>2-6字节/参数]
B --> F[优化器状态<br/>8-12字节/参数]
C --> G[前向传播激活值<br/>与批次大小和序列长度相关]
C --> H[临时缓冲区<br/>框架开销]
D --> I[总计: 16-20字节/参数]
ZeRO:分片的艺术
面对显存瓶颈,微软在2019年提出了ZeRO(Zero Redundancy Optimizer),这是一个改变游戏规则的优化方案。ZeRO的核心洞察在于:传统的数据并行训练中,每个GPU都保存完整的模型状态副本,造成了大量冗余。
数据并行的冗余问题
在标准的数据并行训练中,假设使用N个GPU:
- 每个GPU保存完整的模型参数副本
- 每个GPU保存完整的梯度副本
- 每个GPU保存完整的优化器状态副本
这意味着模型状态的总存储量是N倍于单GPU的情况。当模型规模增大时,这种冗余变得不可接受。
ZeRO的三阶段分片策略
ZeRO通过分片(Sharding)技术消除冗余,将模型状态分散到不同GPU上。它定义了三个递进的优化阶段:
ZeRO Stage 1:优化器状态分片
这是最温和的优化。仅将优化器状态分片到各GPU,其他组件保持完整副本。具体而言:
- 每个GPU保存 1/N 的优化器状态
- 每个GPU仍保存完整的模型参数和梯度
在参数更新时,每个GPU只更新自己负责的参数分片,然后通过all-gather通信同步更新后的参数。
内存节省计算:对于混合精度+Adam训练,原本每个GPU需要存储12字节的优化器状态每参数。使用ZeRO-1后,每个GPU只需存储 12/N 字节每参数。
ZeRO Stage 2:优化器状态+梯度分片
Stage 2在Stage 1的基础上,进一步分片梯度:
- 每个GPU保存 1/N 的优化器状态
- 每个GPU保存 1/N 的梯度
- 每个GPU仍保存完整的模型参数
反向传播时,使用reduce-scatter操作将梯度汇总到对应的GPU,而不是使用all-reduce。这样每个GPU只存储自己负责的那部分梯度。
内存节省:优化器状态12字节 + 梯度2字节 = 14字节每参数,分片后为 14/N 字节每参数。
ZeRO Stage 3:全面分片
Stage 3是最激进的优化,将所有模型状态都进行分片:
- 每个GPU保存 1/N 的优化器状态
- 每个GPU保存 1/N 的梯度
- 每个GPU保存 1/N 的模型参数
这带来了额外的通信开销:在前向和反向传播过程中,需要通过all-gather操作临时收集完整的参数。但内存节省也是最大的:所有模型状态(约16-20字节每参数)都被分片,每个GPU只需存储 16/N 到 20/N 字节每参数。
ZeRO的内存节省效果
以一个100亿参数的模型为例,使用混合精度训练和Adam优化器,假设使用64个GPU:
| 配置 | 每GPU显存需求 |
|---|---|
| 无优化 | 约160GB |
| ZeRO-1 | 约130GB |
| ZeRO-2 | 约120GB |
| ZeRO-3 | 约40GB |
ZeRO-3将显存需求降低了4倍,使得原本无法训练的模型变得可行。
graph LR
A[原始数据并行] --> B[ZeRO-1<br/>优化器状态分片]
B --> C[ZeRO-2<br/>+梯度分片]
C --> D[ZeRO-3<br/>+参数分片]
E[内存占用] --> F[节省1/6]
F --> G[节省1/3]
G --> H[节省N倍]
ZeRO-Infinity与CPU Offload
2021年,DeepSpeed团队进一步扩展了ZeRO,提出了ZeRO-Infinity。其核心创新是将模型状态卸载到CPU内存甚至NVMe SSD:
CPU Offload机制:
- 将优化器状态存储在CPU内存
- 将优化器计算转移到CPU执行
- GPU只存储当前计算所需的参数分片
这使得在单张V100 32GB GPU上训练100亿参数模型成为可能——前提是系统有足够的CPU内存(约150GB)。
CPU Offload的关键挑战是通信开销。PCIe带宽(约32GB/s)远低于GPU HBM带宽(约2TB/s)。为此,DeepSpeed开发了高度优化的CPU Adam实现,比标准PyTorch实现快5-7倍,以充分利用CPU计算能力。
Gradient Checkpointing:时间换空间
ZeRO解决了静态内存的冗余问题,但动态内存(激活值)的增长仍然是一个挑战。Gradient Checkpointing(梯度检查点)提供了一种时间换空间的策略。
激活值存储的困境
在标准训练中,前向传播产生的所有激活值都需要保存,以供反向传播使用。对于深度为L的网络,这意味着需要存储O(L)规模的中间结果。
对于Transformer,激活值大小与层数线性相关。当模型层数增加时,激活值内存会迅速增长。例如,一个100层的Transformer模型,仅激活值就可能占用数百GB显存。
Chen算法:O(√n)内存复杂度
2016年,Chen等人在论文"Training Deep Nets with Sublinear Memory Cost"中提出了一个精妙的解决方案。其核心思想是:不需要保存所有激活值,只需保存部分"检查点",其余激活值可以在反向传播时重新计算。
算法描述:
- 在前向传播时,只保存部分层的激活值作为检查点
- 反向传播需要某层激活值时,从最近的检查点开始重新计算
- 使用后立即释放重新计算的激活值
如果将网络分成√n个检查点段,每个段有√n层,则:
- 需要保存的检查点数:√n
- 每段重新计算的层数:√n
- 内存复杂度:O(√n)
- 额外计算开销:约33%(一次额外的前向传播)
选择性激活重计算
2022年,NVIDIA研究团队在论文"Reducing Activation Recomputation in Large Transformer Models"中提出了更精细的策略——选择性激活重计算(Selective Activation Recomputation)。
关键洞察:Transformer中不同组件的激活值大小差异巨大。注意力分数矩阵(B × N × S × S)与序列长度平方成正比,而其他组件只与序列长度线性相关。
选择性策略:
- 只重计算注意力相关操作(占激活值的大部分)
- 保留其他组件的激活值(重计算代价高,内存占用小)
这种策略将激活值内存降低5倍,同时将重计算开销降低90%以上。在训练530B参数模型时,选择性重计算比全量重计算快29%。
graph TD
A[前向传播] --> B{保存检查点?}
B -->|是| C[存储激活值]
B -->|否| D[不存储]
C --> E[继续传播]
D --> E
F[反向传播] --> G{需要激活值?}
G -->|已存储| H[直接使用]
G -->|未存储| I[从检查点重计算]
H --> J[计算梯度]
I --> J
J --> K[释放临时激活值]
实践中的权衡
Gradient Checkpointing的核心权衡是计算时间与内存空间的交换。典型配置:
| 策略 | 内存节省 | 时间开销 |
|---|---|---|
| 全量检查点 | 50-60% | 约20% |
| 选择性重计算 | 约5倍 | 约10% |
| 完全不检查点 | 基准 | 基准 |
实际应用中,选择性重计算通常是最优选择——它针对内存密集型操作(注意力)进行重计算,同时保留计算密集型操作(FFN)的激活值。
Flash Attention:IO感知的革命
Gradient Checkpointing通过重计算减少内存使用,但根本问题仍然存在:注意力机制的计算复杂度是O(S²),这意味着注意力分数矩阵的大小随序列长度平方增长。2022年提出的Flash Attention从另一个角度解决了这个问题——优化内存访问模式。
GPU内存层次与IO瓶颈
现代GPU有复杂的内存层次:
- HBM(高带宽内存):容量大(80GB),带宽中等(2TB/s)
- SRAM(片上缓存):容量小(每SM约20MB),带宽极高(约20TB/s)
标准注意力实现的IO模式:
- 从HBM读取Q、K矩阵
- 计算QK^T,写入HBM(S×S矩阵)
- 从HBM读取QK^T,计算softmax,写入HBM
- 从HBM读取softmax结果和V矩阵
- 计算注意力输出,写入HBM
问题在于:中间的S×S矩阵需要频繁写入和读取HBM。对于长序列(如S=8192),这个矩阵可达512MB,而HBM带宽有限,成为性能瓶颈。
Flash Attention的分块策略
Flash Attention的核心思想是:不将大型中间矩阵写入HBM,而是通过分块计算在SRAM中完成所有操作。
算法流程:
- 将Q、K、V分块加载到SRAM
- 在SRAM中计算局部注意力分数
- 使用在线softmax算法(数值稳定性保证)
- 累积局部结果到全局输出
- 只将最终输出写入HBM
关键创新:在线softmax算法允许逐块计算softmax而不需要完整的QK^T矩阵。通过维护运行时的最大值和归一化因子,可以在分块情况下得到精确的softmax结果。
内存复杂度优化
标准注意力:O(S²) 内存(需要存储完整的注意力矩阵)
Flash Attention:O(S) 内存(只存储最终输出和少量中间变量)
对于S=16K的序列:
- 标准注意力:约2GB内存
- Flash Attention:约128KB内存
这不仅是数量级的改进,更使得超长序列训练成为可能。GPT-4的128K上下文窗口、Llama 3的1M上下文,都依赖于Flash Attention的实现。
Flash Attention 2与3的演进
Flash Attention 2(2023):
- 优化并行化策略,减少warp间的同步开销
- 更好的工作分区,提高GPU利用率
- 在A100上实现约70%的理论峰值FLOPS
Flash Attention 3(2024):
- 针对Hopper GPU的新特性优化
- 利用Tensor Core和TMA的异步特性
- 使用ping-pong调度重叠GEMM和softmax
- 支持FP8低精度计算
- 在H100上实现约75%的理论峰值FLOPS(FP16)或接近1.2 PFLOPS(FP8)
flowchart TB
subgraph 标准注意力
A1[读取Q,K] --> A2[计算QK^T<br/>写入HBM]
A2 --> A3[读取QK^T<br/>计算softmax]
A3 --> A4[写入HBM]
A4 --> A5[读取softmax,V<br/>计算输出]
A5 --> A6[写入输出]
end
subgraph Flash Attention
B1[分块加载Q,K,V到SRAM] --> B2[SRAM内计算<br/>在线softmax]
B2 --> B3[累积结果]
B3 --> B4[写入最终输出]
end
A7[IO复杂度: O S²] -.-> 标准注意力
B5[IO复杂度: O S] -.-> Flash Attention
3D并行:跨越单机限制
当模型规模增长到单张GPU无法容纳时,需要将模型本身分布到多张GPU上。3D并行是当前大规模训练的标准架构,结合了三种并行策略。
数据并行(Data Parallelism)
数据并行是最简单的策略:
- 每个GPU持有完整的模型副本
- 输入数据分割到各GPU
- 独立进行前向和反向传播
- 使用all-reduce同步梯度
优点:实现简单,扩展性好 缺点:每个GPU都需要存储完整模型,受单GPU显存限制
结合ZeRO后,数据并行可以突破单GPU显存限制,这也是ZeRO的核心价值。
张量并行(Tensor Parallelism)
张量并行在层内分割参数:
- 将大型矩阵乘法分解到多个GPU
- 每个GPU持有参数的一个分片
- 通过all-reduce同步部分结果
以注意力层为例,假设有N个GPU:
- Q、K、V投影矩阵按列分割,每个GPU负责 N 个head
- 每个GPU独立计算部分注意力输出
- 输出投影矩阵按行分割
- 使用all-reduce汇总最终输出
内存效果:
- 参数内存:降低 N 倍
- 激活值内存:部分降低(注意力输出需要同步)
张量并行需要在每个层进行通信,因此最适合在节点内使用(NVLink高速互联)。跨节点使用时,通信开销可能超过计算收益。
流水线并行(Pipeline Parallelism)
流水线并行在层间分割参数:
- 将模型的连续层分配到不同GPU
- 每个GPU只持有部分层
- 数据以"流水线"方式依次流经各GPU
以4路流水线为例,24层模型:
- GPU 0:第1-6层
- GPU 1:第7-12层
- GPU 2:第13-18层
- GPU 3:第19-24层
内存效果:
- 参数内存:降低 N 倍
- 梯度内存:降低 N 倍(每GPU只存储自己层的梯度)
- 激活值内存:需要存储所有micro-batch的激活值(流水线气泡)
流水线并行的挑战是"气泡"(bubble)——GPU空闲等待的时间。使用GPipe或1F1B调度可以减少气泡比例。
3D并行的组合策略
在大规模训练中,三种并行策略通常组合使用:
典型配置:
- 张量并行(TP):节点内,利用NVLink高速互联
- 流水线并行(PP):跨节点,容忍较高的网络延迟
- 数据并行(DP):剩余维度,配合ZeRO进一步降低内存
并行度计算:
$$DP = \frac{总GPU数}{TP \times PP}$$内存计算:
$$M_{GPU} = \frac{M_{params}}{TP \times PP} + \frac{M_{optimizer}}{N_{total}} + \frac{M_{activations}}{TP}$$以训练一个1000亿参数模型为例,使用128个GPU(8个节点,每节点8 GPU):
- TP = 8(节点内全张量并行)
- PP = 4(跨节点流水线)
- DP = 4(配合ZeRO-1)
每个GPU的静态内存:
- 参数:约25GB / (8 × 4) ≈ 0.8GB
- 优化器:约50GB / 128 ≈ 0.4GB
- 梯度:约6GB / 4 ≈ 1.5GB
这远低于单GPU的显存容量,为激活值留出充足空间。
graph TB
subgraph 3D并行架构
A[数据并行<br/>ZeRO分片] --> B[GPU组1]
A --> C[GPU组2]
A --> D[GPU组3]
A --> E[GPU组4]
B --> F[流水线并行<br/>层间分割]
F --> G[Stage 1]
F --> H[Stage 2]
F --> I[Stage 3]
F --> J[Stage 4]
G --> K[张量并行<br/>层内分割]
K --> L[GPU 0-7]
end
序列并行(Sequence Parallelism)
序列并行是张量并行的扩展,专门针对长序列场景。核心思想是:将序列维度分割到多个GPU,而非在attention层复制完整序列。
在Megatron-LM的实现中:
- LayerNorm和Dropout操作沿序列维度分片
- 注意力层仍然需要完整的序列(或使用Ring Attention)
- 通过all-gather和reduce-scatter同步
内存效果:激活值内存降低 TP 倍,特别是与序列长度相关的部分。
对于超长序列(如100K tokens),序列并行是必需的技术。Meta在训练Llama 3时,就使用了序列并行来支持128K上下文窗口。
混合精度与量化:精度换空间
除了系统级优化,数值精度的选择也直接影响显存消耗。混合精度训练和量化技术提供了另一个维度的优化空间。
混合精度训练的原理
混合精度训练的核心是:不同操作使用不同精度。
高精度操作:
- 参数更新(需要FP32保证数值稳定性)
- 损失缩放(防止梯度下溢)
低精度操作:
- 矩阵乘法(利用Tensor Core加速)
- 前向和反向传播的主体计算
PyTorch的AMP(Automatic Mixed Precision)框架自动处理精度转换:
- 前向传播:FP16计算
- 损失缩放:避免梯度下溢
- 反向传播:FP16计算
- 参数更新:FP32
内存收益:
- 模型权重:FP32 → FP16,节省50%
- 激活值:自动FP16存储
- 优化器状态:仍需FP32
BF16的优势
BF16(Brain Float 16)是Google提出的另一种16位浮点格式:
| 格式 | 符号位 | 指数位 | 尾数位 | 动态范围 |
|---|---|---|---|---|
| FP16 | 1 | 5 | 10 | 有限 |
| BF16 | 1 | 8 | 7 | 与FP32相同 |
| FP32 | 1 | 8 | 23 | 最大 |
BF16的优势在于动态范围与FP32相同,不需要损失缩放。这使得训练更加稳定,在Ampere及更新架构GPU上成为首选。
训练时量化
更激进的策略是在训练时就使用低精度表示:
FP8训练:
- Hopper GPU原生支持FP8计算
- 需要仔细处理量化误差
- Flash Attention 3的incoherent processing技术可降低量化误差2.6倍
INT8优化器:
- 将Adam的动量和方差量化为INT8
- bitsandbytes库提供了实现
- 优化器状态从12字节/参数降低到6字节/参数
内存对比:
- 标准FP32训练:20字节/参数
- 混合精度FP16:16字节/参数
- 混合精度+INT8优化器:10字节/参数
- 全FP8训练:约5字节/参数(实验性)
技术演进的脉络
回顾过去五年,显存优化技术的发展呈现出清晰的脉络:
2019年及之前:
- 混合精度训练成为标配
- 模型并行是大规模训练的主要手段
- 激活重计算开始被关注
2020年:
- ZeRO论文发表,分片优化成为新范式
- GPT-3的训练证明了大规模分布式训练的可行性
2021年:
- ZeRO-Offload将CPU内存纳入显存池
- Megatron-LM成熟,3D并行成为标准架构
2022年:
- Flash Attention革命性地优化了注意力计算
- 选择性激活重计算显著降低重计算开销
- 序列并行支持超长上下文
2023-2024年:
- Flash Attention 2/3持续优化
- FP8训练开始实用化
- MoE(Mixture of Experts)带来新的并行策略
这些技术不是孤立的,而是相互补充、组合使用。一个典型的现代大模型训练配置会同时使用:
- 混合精度训练(基础)
- ZeRO-3(分布式优化器)
- Gradient Checkpointing(激活值优化)
- Flash Attention(注意力优化)
- 3D并行(跨GPU扩展)
实践指南:显存估算与配置
理解了各项技术,如何在实践中应用?以下是显存估算和配置的基本方法。
显存估算公式
静态内存(混合精度+Adam):
$$M_{static} = 16 \times P + \frac{12 \times P}{N_{ZeRO}}$$其中P是参数量,N_{ZeRO}是ZeRO分片的GPU数。
激活值内存:
$$M_{activation} = S \times B \times H \times L \times (10 + \frac{24}{TP} + \frac{5 \times a \times S}{H \times TP})$$其中TP是张量并行度,a是注意力头数。使用选择性重计算后,最后一项可以忽略。
实例分析
以训练一个70亿参数模型为例:
模型配置:
- 参数量:7B
- 层数:32
- 隐藏维度:4096
- 注意力头:32
静态内存(单GPU,无优化):
- 权重(FP16+FP32):6 × 7B = 42GB
- 梯度(FP16):2 × 7B = 14GB
- 优化器状态:12 × 7B = 84GB
- 总计:约140GB
使用ZeRO-3,64 GPU:
- 每GPU参数:42GB / 64 ≈ 0.7GB
- 每GPU梯度:14GB / 64 ≈ 0.2GB
- 每GPU优化器:84GB / 64 ≈ 1.3GB
- 静态总计:约2.2GB/GPU
激活值内存(batch=1, seq=2048, TP=8):
- 无重计算:约28GB/GPU
- 选择性重计算:约6GB/GPU
最终每GPU显存需求:约8-10GB(包含框架开销)
这意味着使用ZeRO-3、选择性重计算和适当的并行配置,可以在A100 80GB GPU上以较大的批次大小训练70亿参数模型。
配置决策树
flowchart TD
A[开始配置] --> B{模型是否单GPU可容纳?}
B -->|是| C[数据并行+ZeRO-1]
B -->|否| D{是否跨节点?}
D -->|否| E[张量并行<br/>节点内GPU全用]
D -->|是| F[张量并行+流水线并行]
C --> G{激活值是否超限?}
E --> G
F --> G
G -->|是| H[启用选择性重计算]
G -->|否| I[继续评估]
H --> J{序列是否超长?}
I --> J
J -->|是| K[启用序列并行]
J -->|否| L[最终配置]
K --> L
未来展望
显存优化技术的发展远未结束。几个值得关注的趋势:
硬件演进:
- HBM4将提供2TB/s+带宽
- 大容量GPU(如H200 141GB)减少分片需求
- 专用AI芯片可能提供不同的内存架构
算法创新:
- 线性注意力机制尝试从根本上解决O(S²)复杂度
- 状态空间模型(如Mamba)提供不同的序列建模范式
- 稀疏注意力减少计算量和内存使用
系统优化:
- 更智能的自动调优系统
- 异构内存管理(GPU+CPU+SSD)
- 容错训练支持(应对大规模集群的硬件故障)
显存优化已经成为大模型训练的核心课题。从ZeRO的分片策略到Flash Attention的IO优化,每一步创新都推动着模型规模的边界。理解这些技术的原理和权衡,是每个大模型从业者的必修课。
参考文献
- Rajbhandari, S., Rasley, J., Ruwase, O., & He, Y. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC20.
- Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv:1604.06174.
- Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
- Korthikanti, V., et al. (2023). Reducing Activation Recomputation in Large Transformer Models. MLSys 2023.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
- Dao, T., et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
- Ren, J., et al. (2021). ZeRO-Offload: Democratizing Billion-Scale Model Training. USENIX ATC 2021.
- Shoeybi, M., et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053.
- Narayanan, D., et al. (2021). Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC21.
- Anthony, Q., Biderman, S., & Schoelkopf, H. (2023). Transformer Math 101. EleutherAI Blog.