你正在训练一个Transformer模型,Loss曲线稳定下降,一切看起来都很顺利。然后你决定启用混合精度训练来加速——只需一行代码.half()。100步之后:Loss: NaN。训练彻底崩溃。

模型没有忘记如何学习,数学公式也没有改变。但GPU能表示的数字范围变了。在深度学习中,这个差异比你想象的更重要。

这不是一个罕见的问题。GPT-3、BERT以及每一个现代Transformer模型都曾面对这个挑战。让我们从头理解这个问题的本质,以及为什么一行简单的x - max(x)能够拯救你的训练。

Transformer模型架构

浮点数:计算机如何存储"实数"

当你写下weight = 2.5时,GPU存储的不是精确的2.5,而是一个二进制近似。IEEE 754标准定义了浮点数的存储方式,它使用三个部分来表示一个数:

$$v = (-1)^s \times 2^{e-bias} \times (1 + fraction)$$

其中s是符号位,e是指数位,fraction是尾数(也叫有效数字)。这种表示方式像科学计数法,但有一个关键限制:位数是固定的。

FP32(单精度浮点数) 使用32位存储:

  • 1位符号
  • 8位指数
  • 23位尾数

这给出了约$\pm 3.4 \times 10^{38}$的范围和约7位十进制精度。

FP16(半精度浮点数) 使用16位存储:

  • 1位符号
  • 5位指数
  • 10位尾数

范围缩减到$\pm 65,504$,精度只有约3位十进制数字。

浮点数表示示例

这里有一个关键洞察:FP16的动态范围严重不足。最大值65,504意味着当输入$x > 11$时,$e^x$就会溢出——因为$e^{11} \approx 59,874$,已经接近FP16的表示上限。

import numpy as np

# FP16的噩梦
x = np.array([10.0], dtype=np.float16)
print(np.exp(x))  # 正常: [22026.]

x = np.array([12.0], dtype=np.float16)
print(np.exp(x))  # 溢出: [inf]

这就是为什么在混合精度训练中,看似正常的计算会突然崩溃。

Softmax:一个对数值稳定性极度敏感的函数

Softmax函数将一组实数转换为概率分布:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

这个定义在数学上完美无瑕,但在数值计算中存在两个致命陷阱:

数值计算中的问题可视化

溢出陷阱

当输入值很大时,指数运算会迅速超过浮点数的表示范围:

def unstable_softmax(x):
    return np.exp(x) / np.sum(np.exp(x))

# 看似正常的输入
x = np.array([1000.0, 1000.0, 1000.0])
print(unstable_softmax(x))
# 输出: [nan, nan, nan]
# 原因: exp(1000) = inf, inf/inf = nan

在FP16下,这个问题更加严重。Transformer的注意力分数(QK^T的结果)可能产生很大的值,特别是对于长序列。当序列长度增加时,点积结果的方差也随之增大,很容易产生超过11的值。

下溢陷阱

当输入值是很小的负数时,指数运算会得到接近零的结果:

x = np.array([-1000.0, -1000.0, -1000.0])
print(np.exp(x))
# 输出: [0., 0., 0.]

下溢本身不会立即导致崩溃,但会使分母变成零,或者在后续计算中丢失所有精度。

Log-Sum-Exp技巧:一个优雅的数学救星

核心观察:在Softmax的分子分母中,我们可以乘以任意常数而不改变结果。这个看似平凡的观察,正是解决数值稳定性的钥匙。

数值稳定性解决方案

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i} \cdot C}{\sum_j e^{x_j} \cdot C} = \frac{e^{x_i + \ln C}}{\sum_j e^{x_j + \ln C}}$$

如果我们选择$\ln C = -\max(x)$,那么:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

这个变换的精妙之处在于:

  1. 防止溢出:减去最大值后,所有指数的输入$x_i - \max(x) \leq 0$,因此$e^{x_i - \max(x)} \leq 1$。永远不会溢出。

  2. 保证至少一个非零值:当$x_i = \max(x)$时,$e^0 = 1$。分母至少为1,不会出现除以零的错误。

  3. 数学等价:结果与原始Softmax完全相同,只是计算过程更稳定。

def stable_softmax(x):
    x_shifted = x - np.max(x)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x)

# 现在可以处理任意大的输入
x = np.array([1000.0, 1000.0, 1000.0])
print(stable_softmax(x))
# 输出: [0.333..., 0.333..., 0.333...]

x = np.array([1.0, 2.0, 3.0])
print(stable_softmax(x))
# 输出: [0.0900..., 0.2447..., 0.6652...]

Log-Sum-Exp函数

在实践中,我们经常需要计算Log-Softmax(Softmax的对数),特别是在交叉熵损失中。直接计算log(softmax(x))会在概率接近零时产生log(0) = -inf的问题。

推导过程:

$$\log(\text{softmax}(x_i)) = \log\left(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\right)$$$$= (x_i - \max(x)) - \log\left(\sum_j e^{x_j - \max(x)}\right)$$

这里$\log\sum e^{x_j - \max(x)}$就是著名的Log-Sum-Exp (LSE) 函数。由于$\sum_j e^{x_j - \max(x)} \geq 1$(至少有一个$e^0 = 1$),这个对数永远不会出现$\log(0)$的问题。

def log_softmax(x):
    x_max = np.max(x)
    return x - x_max - np.log(np.sum(np.exp(x - x_max)))

def cross_entropy_loss(logits, target):
    """数值稳定的交叉熵损失"""
    return -log_softmax(logits)[target]

PyTorch的F.log_softmaxF.cross_entropy底层就是这样实现的,它们将Softmax和对数操作融合在一起,既提高了数值稳定性,又减少了计算开销。

混合精度训练:当FP16遇上梯度下溢

启用了FP16混合精度训练后,新的问题出现了。即使Softmax本身是稳定的,梯度可能在其他地方崩溃。

混合精度训练示意图

梯度下溢问题

训练过程中,梯度的大小往往随着反向传播逐层减小:

Layer 1: gradient = 0.0001
Layer 2: gradient = 0.00003
Layer 3: gradient = 0.000007
Layer 4: gradient = 0.0000001

FP16能表示的最小正数约为$6 \times 10^{-8}$(归一化数)。第四层的梯度$1 \times 10^{-7}$会被舍入为零。

一项2018年的研究发现,在Multibox SSD检测器网络上,31%的梯度在转换为FP16时变成零,只有5.3%保持非零。训练完全发散。

Loss Scaling:强行放大梯度

解决方案非常直观:如果梯度太小,就在存储到FP16之前先放大。

静态Loss Scaling

# 手动实现loss scaling
scale = 1024.0
loss = model(input)
scaled_loss = loss * scale
scaled_loss.backward()

# 梯度现在是原来的1024倍,可以安全存储在FP16中
for param in model.parameters():
    param.grad = param.grad / scale  # 恢复原始大小

optimizer.step()

动态Loss Scaling(PyTorch的GradScaler):

静态缩放的问题是:缩放因子太小则仍有下溢,太大则可能溢出。动态缩放自动调整:

  1. 从较大的缩放因子开始(如65536)
  2. 每次反向传播后检查梯度是否有inf/nan
  3. 如果有:跳过本次更新,缩放因子减半
  4. 如果连续N次(如2000次)没有问题:缩放因子翻倍
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for batch in dataloader:
    with autocast():  # 自动选择FP16计算
        loss = model(batch)
    
    scaler.scale(loss).backward()  # 自动缩放
    scaler.step(optimizer)         # 检查梯度,必要时跳过更新
    scaler.update()                # 调整缩放因子
    
    optimizer.zero_grad()

Master Weights:累积小更新的智慧

即使解决了梯度下溢,还有一个隐蔽的问题:权重更新的精度损失

FP16的精度约为3位十进制数字。当权重接近1时,小于0.001的更新会被舍入掉:

weight_fp16 = 1.0
update = 0.0001
weight_fp16 += update  # 结果仍然是1.0!

解决方案是维护FP32的"主权重"副本。每次更新时,在FP32精度下累积更新,然后截断到FP16用于前向传播:

# 伪代码展示原理
master_weight_fp32 = model.weight.clone().float()

# 前向传播使用FP16
output = model(input.half())

# 反向传播
loss.backward()

# 在FP32中累积更新
update = learning_rate * model.weight.grad.float()
master_weight_fp32 += update  # 精度完整保留

# 复制回FP16
model.weight.data = master_weight_fp32.half()

这就是为什么混合精度训练框架(如PyTorch AMP)在启用时,优化器内部自动维护FP32主权重。

BFloat16:用精度换范围的明智选择

FP16的根本问题是5位指数带来的窄动态范围。Google提出了一个大胆的想法:保留FP32的指数位,只减少尾数

BFloat16 (Brain Float 16) 格式:

  • 1位符号
  • 8位指数(与FP32相同)
  • 7位尾数
格式 指数位 尾数位 动态范围 精度
FP32 8 23 $\pm 3.4 \times 10^{38}$ ~7位十进制
FP16 5 10 $\pm 65,504$ ~3位十进制
BF16 8 7 $\pm 3.4 \times 10^{38}$ ~2位十进制

BF16拥有与FP32相同的动态范围,这意味着:

  • Softmax计算不需要特殊的溢出处理
  • 梯度下溢极少发生
  • 不需要Loss Scaling
# BF16训练:简单直接
with autocast(dtype=torch.bfloat16):
    loss = model(batch)
loss.backward()
optimizer.step()  # 没有GradScaler,没有额外技巧

这正是为什么A100/H100等新一代GPU和TPU原生支持BF16的原因。对于大模型训练,BF16是比FP16更好的默认选择——它"开箱即用"。

Flash Attention中的在线Softmax

Flash Attention是一种内存高效的注意力实现,它通过将注意力计算分块处理来减少内存访问。但这带来了一个挑战:如何在分块计算中保持Softmax的数值稳定性?

Flash Attention架构

标准Softmax需要三次遍历:

  1. 找到最大值
  2. 计算指数和
  3. 归一化

而Flash Attention使用在线Softmax算法,将前两次遍历合并:

def online_softmax(x):
    """在线Softmax:一次遍历计算max和sum"""
    n = len(x)
    
    # 初始化:第一个元素
    m = x[0]        # 当前最大值
    d = 1.0         # 当前归一化因子
    
    for i in range(1, n):
        if x[i] > m:
            # 发现新的最大值,需要调整之前的累积和
            d = d * np.exp(m - x[i]) + 1.0
            m = x[i]
        else:
            d += np.exp(x[i] - m)
    
    # 最终归一化
    return np.exp(x - m) / d

关键洞察:当我们发现新的最大值$m_{new}$时,之前计算的指数和需要重新归一化。所有之前累加的项都要乘以$e^{m_{old} - m_{new}}$来补偿基线的变化。

这个算法使得Flash Attention能够逐块处理注意力矩阵,同时在SRAM中保持数值稳定性,避免了将整个注意力矩阵写入高延迟的HBM。

实践指南:诊断和解决数值问题

症状1:训练开始后Loss变成NaN

检查梯度范数:

total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
  • 健康值:0.1-10
  • 爆炸(>100):添加梯度裁剪
  • 消失(<0.001):检查学习率或网络结构

症状2:Loss Scale持续下降

print(f"Current loss scale: {scaler.get_scale()}")

如果缩放因子降到1.0并保持:

  • 考虑切换到BF16
  • 检查是否存在数值不稳定的层(如极大激活值)

症状3:特定层梯度为零

for name, param in model.named_parameters():
    if param.grad is not None:
        grad_mean = param.grad.abs().mean().item()
        if grad_mean < 1e-10:
            print(f"Zero gradient in {name}")

最佳实践清单

优先使用BF16(如果硬件支持):

  • A100/H100 GPU
  • TPU
  • 不需要Loss Scaling

FP16训练的正确方式

from torch.cuda.amp import GradScaler, autocast

model = Model().cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler()

for batch in dataloader:
    with autocast():
        output = model(batch)
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

避免手动.half()转换

# 错误:直接转FP16会导致训练崩溃
model = Model().half()  

# 正确:使用autocast自动管理精度
with autocast():
    output = model(input)

敏感操作保持FP32: PyTorch的autocast自动处理以下操作:

  • Layer Normalization
  • Batch Normalization
  • Softmax
  • Cross Entropy Loss

如果手动实现,确保这些操作在FP32精度下执行。

从数学到工程的完整链条

数值稳定性问题从浮点数的物理表示开始,贯穿Softmax的数学定义、梯度计算的精度损失、混合精度训练的权衡,直到Flash Attention的分块算法设计。

每一步都不是孤立的优化技巧,而是对"计算机如何表示数字"这一基础约束的系统性回应:

  1. Log-Sum-Exp:利用Softmax的数学性质(平移不变性)消除溢出风险
  2. Loss Scaling:用链式法则将梯度临时放大到可表示范围
  3. Master Weights:在高精度空间累积微小的权重更新
  4. BF16:从格式设计层面解决动态范围不足的问题
  5. Online Softmax:在分块计算中保持数值稳定性的增量算法

理解这些技术背后的数学原理,才能在面对新的训练崩溃时,系统性地定位和解决问题。当你的模型Loss变成NaN时,不再是无助的调试,而是基于浮点数表示限制的理性诊断。


参考资料

  1. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press. Chapter 4: Numerical Computation.

  2. NVIDIA. (2023). Mixed Precision Training Documentation. NVIDIA Deep Learning Performance.

  3. Wang, S., & Kanwar, P. (2019). BFloat16: The Secret Behind High Performance on Cloud TPUs. Google Cloud Blog.

  4. Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

  5. Micikevicius, P., et al. (2018). Mixed Precision Training. ICLR 2018.

  6. Dive into Deep Learning (D2L). Chapter 4.5: Numerical Stability and Initialization.

  7. PyTorch Documentation. Automatic Mixed Precision (AMP) Package.

  8. Kahan, W. (1996). IEEE 754R meeting notes on overflow and underflow.

  9. Milakov, M., & Gimelshein, N. (2018). Online Normalizer Calculation for Softmax. arXiv:1805.02867.

  10. Moler, C. (2017). Half Precision Arithmetic: 16-bit Floating Point Numbers. MATLAB Blog.