训练大模型时,显存总是第一个瓶颈。以一个拥有70亿参数的LLaMA 7B模型为例,仅存储FP32权重就需要约28GB显存。再加上梯度、优化器状态和中间激活值,即使是一张A100 80GB显卡也可能捉襟见肘。

于是开发者们开始寻找解决方案。一个自然的想法是:能不能用更低的精度存储数据?毕竟神经网络的权重大多是-1到1之间的小数,用32位浮点数似乎有些浪费。

这个直觉在2017年被NVIDIA和Baidu的研究团队验证。他们发现,只要处理得当,用16位浮点数训练出的模型,精度几乎不逊色于32位——甚至有时候还更好。这项技术被称为混合精度训练(Mixed Precision Training),它不仅将显存占用减半,还能让训练速度提升数倍。

但问题来了:既然16位精度够用,为什么不直接全部用16位,而要坚持"混合"?为什么有些模型用FP16训练会崩,换成BF16就没问题?Loss Scaling又是怎么回事?

这些问题背后,藏着浮点数表示、数值稳定性和硬件架构设计的深层逻辑。

IEEE 754:浮点数的三位一体

要理解混合精度训练的困境,先得理解浮点数是如何存储的。

IEEE 754标准定义了浮点数的存储格式,核心思想是科学计数法的变种:任何浮点数都可以表示为:

$$(-1)^{sign} \times 1.mantissa \times 2^{exponent - bias}$$

一个浮点数由三部分组成:

  • 符号位(Sign):决定正负
  • 指数位(Exponent):决定数值范围
  • 尾数位(Mantissa/Fraction):决定精度

这三者的位宽分配,决定了这个浮点数格式的"性格"。

FP32:标准单精度

FP32(Single Precision)是最常用的浮点格式:

block-beta
    columns 3
    block:fp32:1
        sign["符号位 (1 bit)"]
    end
    block:exp:1
        exponent["指数位 (8 bits)"]
    end
    block:man:1
        mantissa["尾数位 (23 bits)"]
    end
  • 指数位8位,偏置值127
  • 尾数位23位
  • 动态范围:$\approx 10^{-38}$ 到 $\approx 10^{38}$
  • 精度:约7位有效数字

这个精度对于科学计算足够,但对于深度学习来说,似乎有些奢侈。

FP16:IEEE标准的半精度

FP16(Half Precision)是IEEE 754标准定义的16位浮点格式:

block-beta
    columns 3
    block:fp16:1
        sign["符号位 (1 bit)"]
    end
    block:exp2:1
        exponent["指数位 (5 bits)"]
    end
    block:man2:1
        mantissa["尾数位 (10 bits)"]
    end
  • 指数位5位,偏置值15
  • 尾数位10位
  • 动态范围:$\approx 6 \times 10^{-5}$ 到 $65504$
  • 精度:约3位有效数字

问题出现了:FP16的最大值只有65504。超过这个值就会变成inf(无穷大),这在神经网络训练中是致命的——梯度一旦溢出,权重更新就会失败。

更麻烦的是下限。FP16能表示的最小正数是$6 \times 10^{-5}$,比这更小的值会被截断为0。而神经网络训练中的梯度,常常比这个值还小。

BF16:Google的工程妥协

BF16(Brain Float 16)是Google Brain团队设计的一种16位格式,它做了一个关键的权衡:

block-beta
    columns 3
    block:bf16:1
        sign["符号位 (1 bit)"]
    end
    block:exp3:1
        exponent["指数位 (8 bits)"]
    end
    block:man3:1
        mantissa["尾数位 (7 bits)"]
    end
  • 指数位8位(与FP32相同!)
  • 尾数位7位
  • 动态范围:$\approx 10^{-38}$ 到 $\approx 3.4 \times 10^{38}$(与FP32几乎相同!)
  • 精度:约2-3位有效数字

BF16的精髓在于:保留了FP32的动态范围,牺牲了精度

这为什么对深度学习很重要?

神经网络有一个神奇的特性:它对精度的要求远低于对动态范围的要求。权重的微小变化(比如从0.1234567变成0.12)通常不会显著影响模型性能,但如果梯度因为数值范围问题变成0或inf,训练就会失败。

graph LR
    subgraph FP32
        A1["范围: 10^-38 ~ 10^38"]
        A2["精度: ~7位"]
    end
    subgraph FP16
        B1["范围: 10^-5 ~ 65504"]
        B2["精度: ~3位"]
    end
    subgraph BF16
        C1["范围: 10^-38 ~ 10^38"]
        C2["精度: ~2位"]
    end
    FP32 -->|"截断尾数"| BF16
    FP32 -->|"重新分配"| FP16
    style B1 fill:#ffcccc
    style C1 fill:#ccffcc

图中红色标注的是FP16的致命缺陷——动态范围太窄;绿色标注的是BF16的优势——动态范围与FP32相同。

混合精度训练:为什么不能全用16位

既然BF16的动态范围与FP32相同,为什么还需要"混合"精度?

答案藏在权重更新的细节里。

假设学习率是$10^{-4}$,梯度是$10^{-3}$,那么权重更新量是$10^{-7}$。如果权重本身在1附近,用FP16或BF16存储权重,这个更新量相对于权重来说是可以忽略的精度误差。

但问题在于:权重的更新需要累加很多次梯度

权重更新的精度陷阱

假设权重$w = 1.0$,学习率$\eta = 0.0001$,梯度$g = 0.001$。一次更新的变化量:

$$\Delta w = \eta \times g = 0.0000001$$

如果用FP16存储权重,权重变成了$1.0000001$。但FP16只有约3位有效数字,这个数很可能被四舍五入回$1.0$。

于是,梯度更新"消失"了。

这不是理论上的担忧。2017年的原始论文《Mixed Precision Training》中,研究者发现某些网络在纯FP16训练时完全不收敛——不是因为梯度太小或太大,而是因为权重更新被精度误差吞噬。

解决方案:FP32主权重副本

这就是混合精度训练的第一项核心技术:维护一个FP32的"主权重副本"(Master Weights)

flowchart TD
    subgraph 训练迭代
        A[FP32主权重] -->|转换为FP16| B[FP16权重副本]
        B --> C[前向传播<br/>FP16计算]
        C --> D[计算损失]
        D --> E[反向传播<br/>FP16梯度]
        E --> F[FP32梯度累积]
        F --> G[更新FP32主权重]
        G --> A
    end

具体流程是:

  1. 前向传播前:将FP32主权重转换为FP16副本
  2. 前向和反向传播:使用FP16副本进行计算(节省显存和计算)
  3. 权重更新:梯度先转换为FP32,然后更新FP32主权重
  4. 下一轮迭代:重复上述过程

这样,FP16用于加速计算,FP32用于保证权重更新的精度。显存占用增加了约50%(因为需要存储额外的FP32权重),但仍然比纯FP32节省了大量显存。

但这个方案还不够——梯度本身还有问题。

梯度下溢:Loss Scaling的数学原理

FP16的另一个致命问题是梯度下溢

神经网络训练中,梯度值通常很小,尤其是在深层网络中。FP16能表示的最小正数约为$6 \times 10^{-5}$,而很多梯度值可能只有$10^{-8}$甚至更小。

这些小梯度在FP16中会被截断为0,导致对应的权重永远不会更新。

损失缩放的直觉

解决思路很直接:把梯度放大

但直接放大梯度是不可行的,因为梯度是在反向传播中自动计算的。更聪明的方法是:放大损失函数

假设原始损失是$L$,缩放因子是$S$,我们计算:

$$L' = S \times L$$

根据链式法则,梯度也会被放大同样的倍数:

$$\frac{\partial L'}{\partial w} = S \times \frac{\partial L}{\partial w}$$

这样,原本可能下溢的梯度就被放大到了FP16能表示的范围。

但在更新权重之前,必须把梯度缩放回去:

$$\Delta w = \eta \times \frac{1}{S} \times \frac{\partial L'}{\partial w}$$

静态vs动态缩放

损失缩放的核心问题是:缩放因子选多大?

静态损失缩放使用固定的缩放因子,比如$2^{16} = 65536$。这简单直接,但有风险:如果梯度本来就大,乘以65536后会溢出;如果梯度特别小,可能还是不够。

动态损失缩放则更智能:从较大的缩放因子开始,训练过程中动态调整。如果检测到梯度溢出(出现infNaN),就减小缩放因子并跳过这次更新;如果连续多次没有溢出,就尝试增大缩放因子。

PyTorch的GradScaler使用的就是这个策略:

# PyTorch动态损失缩放的核心逻辑
scaler = torch.amp.GradScaler('cuda')

for data, target in dataloader:
    optimizer.zero_grad()
    
    with torch.amp.autocast("cuda"):
        output = model(data)
        loss = criterion(output, target)
    
    # 缩放损失并反向传播
    scaler.scale(loss).backward()
    
    # 缩放梯度并更新权重
    scaler.step(optimizer)
    scaler.update()

scaler.update()内部实现了动态调整逻辑:检查梯度是否溢出,根据情况调整缩放因子。

BF16为何不需要Loss Scaling

现在可以回答之前的问题:为什么BF16不需要损失缩放?

根本原因在于动态范围。BF16有8位指数位,与FP32相同,能表示的最小正数约为$10^{-38}$。而神经网络训练中的梯度,即使在最深层,也极少小于$10^{-30}$。

换句话说,BF16的动态范围足以覆盖神经网络训练中所有可能出现的数据,包括梯度、权重、激活值等。

但BF16也不是没有代价。它的尾数位只有7位(FP32有23位),精度约为2-3位有效数字。这意味着:

  • 权重值$0.123456$可能被存储为$0.123$
  • 两个接近的数相减可能损失精度
  • 某些对精度敏感的操作可能失败

在实践中,BF16的精度损失通常不影响模型训练,因为神经网络的训练过程本身就是在近似优化,小幅的精度误差会被随机性吸收。

graph TD
    subgraph FP16训练流程
        A1[前向传播] --> B1[计算损失]
        B1 --> C1[损失缩放]
        C1 --> D1[反向传播]
        D1 --> E1[梯度检查<br/>是否溢出?]
        E1 -->|是| F1[减小缩放因子<br/>跳过更新]
        E1 -->|否| G1[梯度反缩放]
        G1 --> H1[更新FP32主权重]
    end
    
    subgraph BF16训练流程
        A2[前向传播] --> B2[计算损失]
        B2 --> C2[反向传播<br/>无需缩放]
        C2 --> D2[更新FP32主权重]
    end
    
    style C1 fill:#ffffcc
    style E1 fill:#ffffcc
    style F1 fill:#ffcccc
    style C2 fill:#ccffcc

混合精度的第三支柱:FP16累加

除了FP32主权重和损失缩放,混合精度训练还有第三个关键技术:FP16矩阵乘法的累加精度

矩阵乘法是深度学习中最频繁的操作,形式为$C = A \times B$。在硬件层面,这涉及大量乘法和加法运算。

问题在于:FP16的精度有限,多个小数相加会累积误差。假设每次加法产生$10^{-4}$的误差,累加1000次后误差就变成了$10^{-1}$,可能影响最终结果。

现代GPU的Tensor Core通过一个巧妙的设计解决这个问题:虽然输入和输出都是FP16,但累加过程使用FP32精度

# Tensor Core的计算模式
for i in range(N):
    accumulator_fp32 += fp16_a[i] * fp16_b[i]  # 累加在FP32中进行
output_fp16 = fp32_to_fp16(accumulator_fp32)   # 最终转换回FP16

这保证了在矩阵乘法的核心计算中,精度损失被最小化。这也是为什么Tensor Core能比传统CUDA Core快数倍的原因——它不仅仅是简单地并行化,还优化了数值精度。

实践指南:何时选择FP16,何时选择BF16

在实际项目中,精度格式的选择取决于硬件、模型和任务。

硬件支持

硬件架构 FP16 BF16 Tensor Core
Volta (V100) FP16
Turing (T4) FP16
Ampere (A100) FP16/BF16
Hopper (H100) FP16/BF16/FP8

V100和T4不支持BF16,只能使用FP16。A100和H100两者都支持,BF16通常是更好的选择。

模型特性

优先选择BF16的场景:

  • 预训练模型已经是BF16格式
  • 模型容易产生梯度下溢(深层网络、长序列)
  • 不想调损失缩放的参数

必须使用FP16的场景:

  • 硬件不支持BF16
  • 与旧代码兼容

需要注意的场景:

  • 模型中有大量exp()log()操作(容易溢出)
  • 自定义CUDA算子(可能不支持低精度)

常见问题排查

问题1:训练开始就出现NaN

可能原因:损失缩放因子太大,导致梯度溢出。

解决方案:

# 从较小的缩放因子开始
scaler = torch.amp.GradScaler('cuda', init_scale=2**8)

问题2:训练中期突然出现NaN

可能原因:

  • 学习率太大
  • 数据中有异常值
  • 模型架构有问题(如除以接近0的数)

调试方法:

# 在backward之后检查梯度
scaler.scale(loss).backward()
print(f"Gradient max: {model.fc1.weight.grad.max()}")
print(f"Gradient min: {model.fc1.weight.grad.min()}")
print(f"Gradient has nan: {torch.isnan(model.fc1.weight.grad).any()}")

问题3:训练收敛但精度下降

可能原因:精度损失累积。

解决方案:

  • 尝试BF16(如果硬件支持)
  • 增加FP32主权重的使用
  • 检查是否有精度敏感的操作被错误地转换为低精度

硬件演进:从Volta到Hopper

混合精度训练的普及,与GPU硬件的演进密不可分。

Volta架构(2017):Tensor Core的诞生

V100是第一个引入Tensor Core的消费级GPU。Tensor Core专门为矩阵乘法设计,能在单个时钟周期内完成$4 \times 4$矩阵的乘加运算。

关键指标:

  • FP16 Tensor Core性能:125 TFLOPS
  • FP32 CUDA Core性能:15.7 TFLOPS
  • 加速比:约8倍

Volta时代,混合精度训练主要是FP16+FP32的组合,需要损失缩放。

Ampere架构(2020):BF16的支持

A100引入了对BF16的硬件支持。这意味着:

  • 前向和反向传播都可以使用BF16
  • 不再需要损失缩放
  • 显存使用减半

关键指标:

  • BF16 Tensor Core性能:312 TFLOPS
  • FP16 Tensor Core性能:312 TFLOPS
  • 相比V100:FP16性能提升2.5倍

Hopper架构(2022):FP8时代

H100引入了FP8(8位浮点数),进一步推动精度边界的探索。FP8有两种变体:

  • E4M3:4位指数,3位尾数(更适合前向传播)
  • E5M2:5位指数,2位尾数(更适合反向传播)

关键指标:

  • FP8 Tensor Core性能:1979 TFLOPS
  • 相比A100 BF16:性能提升约6倍

FP8的训练更复杂,需要更精细的缩放策略,但对于超大模型(如GPT-4级别)的训练,这是突破显存瓶颈的关键技术。

代码实战:PyTorch AMP的正确姿势

最后,给出一份生产环境可用的混合精度训练模板:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, train_loader, optimizer, criterion, epochs, use_bf16=False):
    """
    混合精度训练模板
    
    Args:
        model: 神经网络模型
        train_loader: 训练数据加载器
        optimizer: 优化器
        criterion: 损失函数
        epochs: 训练轮数
        use_bf16: 是否使用BF16(需要Ampere或更新的GPU)
    """
    device = next(model.parameters()).device
    
    # BF16不需要GradScaler
    scaler = None if use_bf16 else GradScaler()
    
    # 选择精度格式
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    
    model.train()
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            
            # 自动混合精度上下文
            with autocast(device_type='cuda', dtype=dtype, enabled=True):
                output = model(data)
                loss = criterion(output, target)
            
            # 反向传播
            if use_bf16:
                # BF16:直接反向传播
                loss.backward()
                optimizer.step()
            else:
                # FP16:使用梯度缩放
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# 使用示例
model = MyLargeModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 选择精度格式
use_bf16 = torch.cuda.is_bf16_supported()  # 自动检测硬件是否支持BF16

train_with_amp(model, train_loader, optimizer, criterion, epochs=10, use_bf16=use_bf16)

关键要点:

  1. 自动检测BF16支持torch.cuda.is_bf16_supported()可以检测当前GPU是否支持BF16
  2. BF16不需要GradScaler:如果使用BF16,可以省略梯度缩放的相关代码
  3. autocast是关键:它会自动判断哪些操作应该使用低精度,哪些应该保持FP32

权衡与边界:混合精度不是万能药

混合精度训练虽然强大,但并非万能。以下是它的边界和权衡:

适用场景

  • 大模型训练(参数量>1亿)
  • 显存受限的场景
  • 对训练速度有要求的场景

不适用场景

  • 对数值精度有严格要求的科学计算
  • 小模型训练(混合精度的收益不明显)
  • 模型本身存在数值不稳定问题(混合精度可能加剧)

代价

  • 实现复杂度增加:需要正确配置autocast和GradScaler
  • 调试难度增加:NaN/Inf问题更难排查
  • 部分操作不支持:某些自定义算子可能不支持低精度

小结

混合精度训练的核心洞察是:神经网络对动态范围敏感,对精度宽容

FP16有足够的精度,但动态范围太窄;BF16保留了FP32的动态范围,牺牲了一些精度。两种格式都需要维护FP32主权重副本来保证权重更新的精度。

FP16需要损失缩放来解决梯度下溢问题,BF16则不需要。选择哪种格式,取决于硬件支持和模型特性。

从Volta到Hopper,GPU架构的演进不断推动着精度边界的探索。FP8的出现意味着我们正在接近数值表示的极限——8位浮点数已经是能想象的最小浮点格式了。

对于今天的从业者来说,混合精度训练已经是大模型训练的标准实践。理解其背后的原理,能帮助你在遇到问题时更快地定位和解决。


参考文献

  1. Micikevicius, P., et al. (2017). Mixed Precision Training. arXiv:1710.03740.
  2. NVIDIA Documentation. Train With Mixed Precision. https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/
  3. PyTorch Documentation. Automatic Mixed Precision package - torch.amp. https://pytorch.org/docs/stable/amp.html
  4. Wikipedia. bfloat16 floating-point format. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
  5. Wang, S. & Kanwar, P. (2019). BFloat16: The secret to high performance on Cloud TPUs. Google Cloud Blog.
  6. NVIDIA Technical Blog. NVIDIA Hopper Architecture In-Depth. (2022)
  7. NVIDIA Documentation. Floating-Point 8: An Introduction to Efficient, Lower-Precision AI Training. (2025)
  8. Higham, N. (2018). Half Precision Arithmetic: fp16 Versus bfloat16.
  9. IJCAI 2020. Reducing Underflow in Mixed Precision Training by Gradient Scaling.