显存墙:大模型训练的第一道坎

2020年,OpenAI训练GPT-3时,1750亿参数的模型需要超过350GB的显存——这远超任何单张GPU的容量。三年后,Meta训练Llama 2-70B时,单张A100 80GB显卡甚至无法完整加载模型权重。显存,而非计算能力,已经成为大模型训练的首要瓶颈。

这个问题的根源在于一个简单但残酷的数学事实:训练一个参数量为P的模型,所需显存远超P字节。在混合精度训练配合Adam优化器的标准配置下,每个参数需要约20字节的显存——模型权重(FP16)2字节、FP32主副本4字节、梯度(FP16)2字节、优化器动量4字节、优化器方差4字节。这意味着一个70亿参数的模型,仅静态内存就需要140GB,而这还不包括训练过程中产生的大量中间激活值。

pie title 单参数显存消耗分解(混合精度+Adam)
    "FP32主副本" : 4
    "优化器动量" : 4
    "优化器方差" : 4
    "FP16权重" : 2
    "FP16梯度" : 2

本文将深入解析过去五年间涌现的显存优化技术,从ZeRO的分片策略到Flash Attention的IO感知算法,揭示这些技术如何将万亿参数模型的训练从理论变为现实。

显存消耗的四重奏

要理解显存优化,首先需要解构训练过程中的显存消耗。GPU显存中的数据可以分为两大类:静态的模型状态和动态的中间激活值。

静态内存:模型状态的固定开销

模型状态包括三个核心组件,它们的大小与参数量成正比:

模型权重:在混合精度训练中,模型权重通常以FP16格式存储,每个参数占用2字节。同时,为了数值稳定性,需要维护一个FP32格式的"主副本",额外占用4字节。这构成了6字节的基础开销。

梯度:反向传播过程中计算的梯度同样需要存储。在混合精度训练中,梯度通常以FP16格式存储,占用2字节。某些框架还会维护FP32格式的梯度副本,进一步增加开销。

优化器状态:这是最容易被低估的内存消耗。以Adam优化器为例,它需要为每个参数维护两个状态变量:一阶动量(momentum)和二阶动量(variance),两者都以FP32格式存储,各占4字节。加上FP32参数副本,优化器状态总计12字节每参数。

将上述组件相加,使用Adam优化器和混合精度训练时,每个参数的静态内存消耗为:

  • FP16权重:2字节
  • FP32主副本:4字节
  • FP16梯度:2字节
  • FP32动量:4字节
  • FP32方差:4字节
  • 总计:16字节每参数(若包含FP32梯度副本则为20字节)

对于一个70亿参数的模型,静态内存需求约为112GB,这已经超过了A100 80GB的容量。

动态内存:激活值的增长规律

激活值是前向传播过程中产生的中间结果,需要保存以供反向传播使用。与静态内存不同,激活值的大小与输入数据的维度密切相关。

以Transformer模型为例,假设输入序列长度为S,批次大小为B,隐藏维度为H,注意力头数为N,每个头的维度为D(H = N × D),模型层数为L。激活值内存的主要来源包括:

自注意力层

  • Q、K、V投影结果:各为 B × N × S × D
  • 注意力分数矩阵:B × N × S × S(这是显存消耗的关键,与序列长度成平方关系)
  • 注意力输出:B × N × S × D

前馈网络层

  • 第一层输出:B × S × 4H(FFN通常将维度扩展4倍)
  • 激活函数输出:B × S × 4H

综合推导,单个Transformer层的激活值内存可表示为:

$$M_{activation}^{layer} = B \times N \times S \times (S + 2D) + 10 \times B \times S \times H$$

第一项来自注意力机制(与序列长度平方成正比),第二项来自其他组件(与序列长度线性相关)。对于长序列场景,第一项将主导总内存消耗。

整个模型的激活值内存为:

$$M_{activations} = L \times M_{activation}^{layer}$$

这个公式揭示了一个关键洞察:激活值内存与批次大小和序列长度成正比,而与参数量间接相关(通过层数和隐藏维度)。这意味着即使模型参数量不变,增加批次大小或序列长度也会显著增加显存需求。

graph TD
    A[GPU显存] --> B[静态内存]
    A --> C[动态内存]
    
    B --> D[模型权重<br/>2-6字节/参数]
    B --> E[梯度<br/>2-6字节/参数]
    B --> F[优化器状态<br/>8-12字节/参数]
    
    C --> G[前向传播激活值<br/>与批次大小和序列长度相关]
    C --> H[临时缓冲区<br/>框架开销]
    
    D --> I[总计: 16-20字节/参数]

ZeRO:分片的艺术

面对显存瓶颈,微软在2019年提出了ZeRO(Zero Redundancy Optimizer),这是一个改变游戏规则的优化方案。ZeRO的核心洞察在于:传统的数据并行训练中,每个GPU都保存完整的模型状态副本,造成了大量冗余。

数据并行的冗余问题

在标准的数据并行训练中,假设使用N个GPU:

  • 每个GPU保存完整的模型参数副本
  • 每个GPU保存完整的梯度副本
  • 每个GPU保存完整的优化器状态副本

这意味着模型状态的总存储量是N倍于单GPU的情况。当模型规模增大时,这种冗余变得不可接受。

ZeRO的三阶段分片策略

ZeRO通过分片(Sharding)技术消除冗余,将模型状态分散到不同GPU上。它定义了三个递进的优化阶段:

ZeRO Stage 1:优化器状态分片

这是最温和的优化。仅将优化器状态分片到各GPU,其他组件保持完整副本。具体而言:

  • 每个GPU保存 1/N 的优化器状态
  • 每个GPU仍保存完整的模型参数和梯度

在参数更新时,每个GPU只更新自己负责的参数分片,然后通过all-gather通信同步更新后的参数。

内存节省计算:对于混合精度+Adam训练,原本每个GPU需要存储12字节的优化器状态每参数。使用ZeRO-1后,每个GPU只需存储 12/N 字节每参数。

ZeRO Stage 2:优化器状态+梯度分片

Stage 2在Stage 1的基础上,进一步分片梯度:

  • 每个GPU保存 1/N 的优化器状态
  • 每个GPU保存 1/N 的梯度
  • 每个GPU仍保存完整的模型参数

反向传播时,使用reduce-scatter操作将梯度汇总到对应的GPU,而不是使用all-reduce。这样每个GPU只存储自己负责的那部分梯度。

内存节省:优化器状态12字节 + 梯度2字节 = 14字节每参数,分片后为 14/N 字节每参数。

ZeRO Stage 3:全面分片

Stage 3是最激进的优化,将所有模型状态都进行分片:

  • 每个GPU保存 1/N 的优化器状态
  • 每个GPU保存 1/N 的梯度
  • 每个GPU保存 1/N 的模型参数

这带来了额外的通信开销:在前向和反向传播过程中,需要通过all-gather操作临时收集完整的参数。但内存节省也是最大的:所有模型状态(约16-20字节每参数)都被分片,每个GPU只需存储 16/N 到 20/N 字节每参数。

ZeRO的内存节省效果

以一个100亿参数的模型为例,使用混合精度训练和Adam优化器,假设使用64个GPU:

配置 每GPU显存需求
无优化 约160GB
ZeRO-1 约130GB
ZeRO-2 约120GB
ZeRO-3 约40GB

ZeRO-3将显存需求降低了4倍,使得原本无法训练的模型变得可行。

graph LR
    A[原始数据并行] --> B[ZeRO-1<br/>优化器状态分片]
    B --> C[ZeRO-2<br/>+梯度分片]
    C --> D[ZeRO-3<br/>+参数分片]
    
    E[内存占用] --> F[节省1/6]
    F --> G[节省1/3]
    G --> H[节省N倍]

ZeRO-Infinity与CPU Offload

2021年,DeepSpeed团队进一步扩展了ZeRO,提出了ZeRO-Infinity。其核心创新是将模型状态卸载到CPU内存甚至NVMe SSD:

CPU Offload机制

  • 将优化器状态存储在CPU内存
  • 将优化器计算转移到CPU执行
  • GPU只存储当前计算所需的参数分片

这使得在单张V100 32GB GPU上训练100亿参数模型成为可能——前提是系统有足够的CPU内存(约150GB)。

CPU Offload的关键挑战是通信开销。PCIe带宽(约32GB/s)远低于GPU HBM带宽(约2TB/s)。为此,DeepSpeed开发了高度优化的CPU Adam实现,比标准PyTorch实现快5-7倍,以充分利用CPU计算能力。

Gradient Checkpointing:时间换空间

ZeRO解决了静态内存的冗余问题,但动态内存(激活值)的增长仍然是一个挑战。Gradient Checkpointing(梯度检查点)提供了一种时间换空间的策略。

激活值存储的困境

在标准训练中,前向传播产生的所有激活值都需要保存,以供反向传播使用。对于深度为L的网络,这意味着需要存储O(L)规模的中间结果。

对于Transformer,激活值大小与层数线性相关。当模型层数增加时,激活值内存会迅速增长。例如,一个100层的Transformer模型,仅激活值就可能占用数百GB显存。

Chen算法:O(√n)内存复杂度

2016年,Chen等人在论文"Training Deep Nets with Sublinear Memory Cost"中提出了一个精妙的解决方案。其核心思想是:不需要保存所有激活值,只需保存部分"检查点",其余激活值可以在反向传播时重新计算。

算法描述:

  1. 在前向传播时,只保存部分层的激活值作为检查点
  2. 反向传播需要某层激活值时,从最近的检查点开始重新计算
  3. 使用后立即释放重新计算的激活值

如果将网络分成√n个检查点段,每个段有√n层,则:

  • 需要保存的检查点数:√n
  • 每段重新计算的层数:√n
  • 内存复杂度:O(√n)
  • 额外计算开销:约33%(一次额外的前向传播)

选择性激活重计算

2022年,NVIDIA研究团队在论文"Reducing Activation Recomputation in Large Transformer Models"中提出了更精细的策略——选择性激活重计算(Selective Activation Recomputation)。

关键洞察:Transformer中不同组件的激活值大小差异巨大。注意力分数矩阵(B × N × S × S)与序列长度平方成正比,而其他组件只与序列长度线性相关。

选择性策略:

  • 只重计算注意力相关操作(占激活值的大部分)
  • 保留其他组件的激活值(重计算代价高,内存占用小)

这种策略将激活值内存降低5倍,同时将重计算开销降低90%以上。在训练530B参数模型时,选择性重计算比全量重计算快29%。

graph TD
    A[前向传播] --> B{保存检查点?}
    B -->|是| C[存储激活值]
    B -->|否| D[不存储]
    C --> E[继续传播]
    D --> E
    
    F[反向传播] --> G{需要激活值?}
    G -->|已存储| H[直接使用]
    G -->|未存储| I[从检查点重计算]
    H --> J[计算梯度]
    I --> J
    J --> K[释放临时激活值]

实践中的权衡

Gradient Checkpointing的核心权衡是计算时间与内存空间的交换。典型配置:

策略 内存节省 时间开销
全量检查点 50-60% 约20%
选择性重计算 约5倍 约10%
完全不检查点 基准 基准

实际应用中,选择性重计算通常是最优选择——它针对内存密集型操作(注意力)进行重计算,同时保留计算密集型操作(FFN)的激活值。

Flash Attention:IO感知的革命

Gradient Checkpointing通过重计算减少内存使用,但根本问题仍然存在:注意力机制的计算复杂度是O(S²),这意味着注意力分数矩阵的大小随序列长度平方增长。2022年提出的Flash Attention从另一个角度解决了这个问题——优化内存访问模式。

GPU内存层次与IO瓶颈

现代GPU有复杂的内存层次:

  • HBM(高带宽内存):容量大(80GB),带宽中等(2TB/s)
  • SRAM(片上缓存):容量小(每SM约20MB),带宽极高(约20TB/s)

标准注意力实现的IO模式:

  1. 从HBM读取Q、K矩阵
  2. 计算QK^T,写入HBM(S×S矩阵)
  3. 从HBM读取QK^T,计算softmax,写入HBM
  4. 从HBM读取softmax结果和V矩阵
  5. 计算注意力输出,写入HBM

问题在于:中间的S×S矩阵需要频繁写入和读取HBM。对于长序列(如S=8192),这个矩阵可达512MB,而HBM带宽有限,成为性能瓶颈。

Flash Attention的分块策略

Flash Attention的核心思想是:不将大型中间矩阵写入HBM,而是通过分块计算在SRAM中完成所有操作。

算法流程:

  1. 将Q、K、V分块加载到SRAM
  2. 在SRAM中计算局部注意力分数
  3. 使用在线softmax算法(数值稳定性保证)
  4. 累积局部结果到全局输出
  5. 只将最终输出写入HBM

关键创新:在线softmax算法允许逐块计算softmax而不需要完整的QK^T矩阵。通过维护运行时的最大值和归一化因子,可以在分块情况下得到精确的softmax结果。

内存复杂度优化

标准注意力:O(S²) 内存(需要存储完整的注意力矩阵)

Flash Attention:O(S) 内存(只存储最终输出和少量中间变量)

对于S=16K的序列:

  • 标准注意力:约2GB内存
  • Flash Attention:约128KB内存

这不仅是数量级的改进,更使得超长序列训练成为可能。GPT-4的128K上下文窗口、Llama 3的1M上下文,都依赖于Flash Attention的实现。

Flash Attention 2与3的演进

Flash Attention 2(2023)

  • 优化并行化策略,减少warp间的同步开销
  • 更好的工作分区,提高GPU利用率
  • 在A100上实现约70%的理论峰值FLOPS

Flash Attention 3(2024)

  • 针对Hopper GPU的新特性优化
  • 利用Tensor Core和TMA的异步特性
  • 使用ping-pong调度重叠GEMM和softmax
  • 支持FP8低精度计算
  • 在H100上实现约75%的理论峰值FLOPS(FP16)或接近1.2 PFLOPS(FP8)
flowchart TB
    subgraph 标准注意力
        A1[读取Q,K] --> A2[计算QK^T<br/>写入HBM]
        A2 --> A3[读取QK^T<br/>计算softmax]
        A3 --> A4[写入HBM]
        A4 --> A5[读取softmax,V<br/>计算输出]
        A5 --> A6[写入输出]
    end
    
    subgraph Flash Attention
        B1[分块加载Q,K,V到SRAM] --> B2[SRAM内计算<br/>在线softmax]
        B2 --> B3[累积结果]
        B3 --> B4[写入最终输出]
    end
    
    A7[IO复杂度: O S²] -.-> 标准注意力
    B5[IO复杂度: O S] -.-> Flash Attention

3D并行:跨越单机限制

当模型规模增长到单张GPU无法容纳时,需要将模型本身分布到多张GPU上。3D并行是当前大规模训练的标准架构,结合了三种并行策略。

数据并行(Data Parallelism)

数据并行是最简单的策略:

  • 每个GPU持有完整的模型副本
  • 输入数据分割到各GPU
  • 独立进行前向和反向传播
  • 使用all-reduce同步梯度

优点:实现简单,扩展性好 缺点:每个GPU都需要存储完整模型,受单GPU显存限制

结合ZeRO后,数据并行可以突破单GPU显存限制,这也是ZeRO的核心价值。

张量并行(Tensor Parallelism)

张量并行在层内分割参数:

  • 将大型矩阵乘法分解到多个GPU
  • 每个GPU持有参数的一个分片
  • 通过all-reduce同步部分结果

以注意力层为例,假设有N个GPU:

  • Q、K、V投影矩阵按列分割,每个GPU负责 N 个head
  • 每个GPU独立计算部分注意力输出
  • 输出投影矩阵按行分割
  • 使用all-reduce汇总最终输出

内存效果:

  • 参数内存:降低 N 倍
  • 激活值内存:部分降低(注意力输出需要同步)

张量并行需要在每个层进行通信,因此最适合在节点内使用(NVLink高速互联)。跨节点使用时,通信开销可能超过计算收益。

流水线并行(Pipeline Parallelism)

流水线并行在层间分割参数:

  • 将模型的连续层分配到不同GPU
  • 每个GPU只持有部分层
  • 数据以"流水线"方式依次流经各GPU

以4路流水线为例,24层模型:

  • GPU 0:第1-6层
  • GPU 1:第7-12层
  • GPU 2:第13-18层
  • GPU 3:第19-24层

内存效果:

  • 参数内存:降低 N 倍
  • 梯度内存:降低 N 倍(每GPU只存储自己层的梯度)
  • 激活值内存:需要存储所有micro-batch的激活值(流水线气泡)

流水线并行的挑战是"气泡"(bubble)——GPU空闲等待的时间。使用GPipe或1F1B调度可以减少气泡比例。

3D并行的组合策略

在大规模训练中,三种并行策略通常组合使用:

典型配置

  • 张量并行(TP):节点内,利用NVLink高速互联
  • 流水线并行(PP):跨节点,容忍较高的网络延迟
  • 数据并行(DP):剩余维度,配合ZeRO进一步降低内存

并行度计算

$$DP = \frac{总GPU数}{TP \times PP}$$

内存计算

$$M_{GPU} = \frac{M_{params}}{TP \times PP} + \frac{M_{optimizer}}{N_{total}} + \frac{M_{activations}}{TP}$$

以训练一个1000亿参数模型为例,使用128个GPU(8个节点,每节点8 GPU):

  • TP = 8(节点内全张量并行)
  • PP = 4(跨节点流水线)
  • DP = 4(配合ZeRO-1)

每个GPU的静态内存:

  • 参数:约25GB / (8 × 4) ≈ 0.8GB
  • 优化器:约50GB / 128 ≈ 0.4GB
  • 梯度:约6GB / 4 ≈ 1.5GB

这远低于单GPU的显存容量,为激活值留出充足空间。

graph TB
    subgraph 3D并行架构
        A[数据并行<br/>ZeRO分片] --> B[GPU组1]
        A --> C[GPU组2]
        A --> D[GPU组3]
        A --> E[GPU组4]
        
        B --> F[流水线并行<br/>层间分割]
        F --> G[Stage 1]
        F --> H[Stage 2]
        F --> I[Stage 3]
        F --> J[Stage 4]
        
        G --> K[张量并行<br/>层内分割]
        K --> L[GPU 0-7]
    end

序列并行(Sequence Parallelism)

序列并行是张量并行的扩展,专门针对长序列场景。核心思想是:将序列维度分割到多个GPU,而非在attention层复制完整序列。

在Megatron-LM的实现中:

  • LayerNorm和Dropout操作沿序列维度分片
  • 注意力层仍然需要完整的序列(或使用Ring Attention)
  • 通过all-gather和reduce-scatter同步

内存效果:激活值内存降低 TP 倍,特别是与序列长度相关的部分。

对于超长序列(如100K tokens),序列并行是必需的技术。Meta在训练Llama 3时,就使用了序列并行来支持128K上下文窗口。

混合精度与量化:精度换空间

除了系统级优化,数值精度的选择也直接影响显存消耗。混合精度训练和量化技术提供了另一个维度的优化空间。

混合精度训练的原理

混合精度训练的核心是:不同操作使用不同精度。

高精度操作

  • 参数更新(需要FP32保证数值稳定性)
  • 损失缩放(防止梯度下溢)

低精度操作

  • 矩阵乘法(利用Tensor Core加速)
  • 前向和反向传播的主体计算

PyTorch的AMP(Automatic Mixed Precision)框架自动处理精度转换:

  • 前向传播:FP16计算
  • 损失缩放:避免梯度下溢
  • 反向传播:FP16计算
  • 参数更新:FP32

内存收益:

  • 模型权重:FP32 → FP16,节省50%
  • 激活值:自动FP16存储
  • 优化器状态:仍需FP32

BF16的优势

BF16(Brain Float 16)是Google提出的另一种16位浮点格式:

格式 符号位 指数位 尾数位 动态范围
FP16 1 5 10 有限
BF16 1 8 7 与FP32相同
FP32 1 8 23 最大

BF16的优势在于动态范围与FP32相同,不需要损失缩放。这使得训练更加稳定,在Ampere及更新架构GPU上成为首选。

训练时量化

更激进的策略是在训练时就使用低精度表示:

FP8训练

  • Hopper GPU原生支持FP8计算
  • 需要仔细处理量化误差
  • Flash Attention 3的incoherent processing技术可降低量化误差2.6倍

INT8优化器

  • 将Adam的动量和方差量化为INT8
  • bitsandbytes库提供了实现
  • 优化器状态从12字节/参数降低到6字节/参数

内存对比:

  • 标准FP32训练:20字节/参数
  • 混合精度FP16:16字节/参数
  • 混合精度+INT8优化器:10字节/参数
  • 全FP8训练:约5字节/参数(实验性)

技术演进的脉络

回顾过去五年,显存优化技术的发展呈现出清晰的脉络:

2019年及之前

  • 混合精度训练成为标配
  • 模型并行是大规模训练的主要手段
  • 激活重计算开始被关注

2020年

  • ZeRO论文发表,分片优化成为新范式
  • GPT-3的训练证明了大规模分布式训练的可行性

2021年

  • ZeRO-Offload将CPU内存纳入显存池
  • Megatron-LM成熟,3D并行成为标准架构

2022年

  • Flash Attention革命性地优化了注意力计算
  • 选择性激活重计算显著降低重计算开销
  • 序列并行支持超长上下文

2023-2024年

  • Flash Attention 2/3持续优化
  • FP8训练开始实用化
  • MoE(Mixture of Experts)带来新的并行策略

这些技术不是孤立的,而是相互补充、组合使用。一个典型的现代大模型训练配置会同时使用:

  • 混合精度训练(基础)
  • ZeRO-3(分布式优化器)
  • Gradient Checkpointing(激活值优化)
  • Flash Attention(注意力优化)
  • 3D并行(跨GPU扩展)

实践指南:显存估算与配置

理解了各项技术,如何在实践中应用?以下是显存估算和配置的基本方法。

显存估算公式

静态内存(混合精度+Adam):

$$M_{static} = 16 \times P + \frac{12 \times P}{N_{ZeRO}}$$

其中P是参数量,N_{ZeRO}是ZeRO分片的GPU数。

激活值内存

$$M_{activation} = S \times B \times H \times L \times (10 + \frac{24}{TP} + \frac{5 \times a \times S}{H \times TP})$$

其中TP是张量并行度,a是注意力头数。使用选择性重计算后,最后一项可以忽略。

实例分析

以训练一个70亿参数模型为例:

模型配置

  • 参数量:7B
  • 层数:32
  • 隐藏维度:4096
  • 注意力头:32

静态内存(单GPU,无优化):

  • 权重(FP16+FP32):6 × 7B = 42GB
  • 梯度(FP16):2 × 7B = 14GB
  • 优化器状态:12 × 7B = 84GB
  • 总计:约140GB

使用ZeRO-3,64 GPU

  • 每GPU参数:42GB / 64 ≈ 0.7GB
  • 每GPU梯度:14GB / 64 ≈ 0.2GB
  • 每GPU优化器:84GB / 64 ≈ 1.3GB
  • 静态总计:约2.2GB/GPU

激活值内存(batch=1, seq=2048, TP=8):

  • 无重计算:约28GB/GPU
  • 选择性重计算:约6GB/GPU

最终每GPU显存需求:约8-10GB(包含框架开销)

这意味着使用ZeRO-3、选择性重计算和适当的并行配置,可以在A100 80GB GPU上以较大的批次大小训练70亿参数模型。

配置决策树

flowchart TD
    A[开始配置] --> B{模型是否单GPU可容纳?}
    B -->|是| C[数据并行+ZeRO-1]
    B -->|否| D{是否跨节点?}
    
    D -->|否| E[张量并行<br/>节点内GPU全用]
    D -->|是| F[张量并行+流水线并行]
    
    C --> G{激活值是否超限?}
    E --> G
    F --> G
    
    G -->|是| H[启用选择性重计算]
    G -->|否| I[继续评估]
    
    H --> J{序列是否超长?}
    I --> J
    
    J -->|是| K[启用序列并行]
    J -->|否| L[最终配置]
    
    K --> L

未来展望

显存优化技术的发展远未结束。几个值得关注的趋势:

硬件演进

  • HBM4将提供2TB/s+带宽
  • 大容量GPU(如H200 141GB)减少分片需求
  • 专用AI芯片可能提供不同的内存架构

算法创新

  • 线性注意力机制尝试从根本上解决O(S²)复杂度
  • 状态空间模型(如Mamba)提供不同的序列建模范式
  • 稀疏注意力减少计算量和内存使用

系统优化

  • 更智能的自动调优系统
  • 异构内存管理(GPU+CPU+SSD)
  • 容错训练支持(应对大规模集群的硬件故障)

显存优化已经成为大模型训练的核心课题。从ZeRO的分片策略到Flash Attention的IO优化,每一步创新都推动着模型规模的边界。理解这些技术的原理和权衡,是每个大模型从业者的必修课。


参考文献

  1. Rajbhandari, S., Rasley, J., Ruwase, O., & He, Y. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC20.
  2. Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv:1604.06174.
  3. Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  4. Korthikanti, V., et al. (2023). Reducing Activation Recomputation in Large Transformer Models. MLSys 2023.
  5. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
  6. Dao, T., et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
  7. Ren, J., et al. (2021). ZeRO-Offload: Democratizing Billion-Scale Model Training. USENIX ATC 2021.
  8. Shoeybi, M., et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053.
  9. Narayanan, D., et al. (2021). Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC21.
  10. Anthony, Q., Biderman, S., & Schoelkopf, H. (2023). Transformer Math 101. EleutherAI Blog.