你正在训练一个Transformer模型,Loss曲线稳定下降,一切看起来都很顺利。然后你决定启用混合精度训练来加速——只需一行代码.half()。100步之后:Loss: NaN。训练彻底崩溃。
模型没有忘记如何学习,数学公式也没有改变。但GPU能表示的数字范围变了。在深度学习中,这个差异比你想象的更重要。
这不是一个罕见的问题。GPT-3、BERT以及每一个现代Transformer模型都曾面对这个挑战。让我们从头理解这个问题的本质,以及为什么一行简单的x - max(x)能够拯救你的训练。

浮点数:计算机如何存储"实数"
当你写下weight = 2.5时,GPU存储的不是精确的2.5,而是一个二进制近似。IEEE 754标准定义了浮点数的存储方式,它使用三个部分来表示一个数:
其中s是符号位,e是指数位,fraction是尾数(也叫有效数字)。这种表示方式像科学计数法,但有一个关键限制:位数是固定的。
FP32(单精度浮点数) 使用32位存储:
- 1位符号
- 8位指数
- 23位尾数
这给出了约$\pm 3.4 \times 10^{38}$的范围和约7位十进制精度。
FP16(半精度浮点数) 使用16位存储:
- 1位符号
- 5位指数
- 10位尾数
范围缩减到$\pm 65,504$,精度只有约3位十进制数字。

这里有一个关键洞察:FP16的动态范围严重不足。最大值65,504意味着当输入$x > 11$时,$e^x$就会溢出——因为$e^{11} \approx 59,874$,已经接近FP16的表示上限。
import numpy as np
# FP16的噩梦
x = np.array([10.0], dtype=np.float16)
print(np.exp(x)) # 正常: [22026.]
x = np.array([12.0], dtype=np.float16)
print(np.exp(x)) # 溢出: [inf]
这就是为什么在混合精度训练中,看似正常的计算会突然崩溃。
Softmax:一个对数值稳定性极度敏感的函数
Softmax函数将一组实数转换为概率分布:
$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$这个定义在数学上完美无瑕,但在数值计算中存在两个致命陷阱:

溢出陷阱
当输入值很大时,指数运算会迅速超过浮点数的表示范围:
def unstable_softmax(x):
return np.exp(x) / np.sum(np.exp(x))
# 看似正常的输入
x = np.array([1000.0, 1000.0, 1000.0])
print(unstable_softmax(x))
# 输出: [nan, nan, nan]
# 原因: exp(1000) = inf, inf/inf = nan
在FP16下,这个问题更加严重。Transformer的注意力分数(QK^T的结果)可能产生很大的值,特别是对于长序列。当序列长度增加时,点积结果的方差也随之增大,很容易产生超过11的值。
下溢陷阱
当输入值是很小的负数时,指数运算会得到接近零的结果:
x = np.array([-1000.0, -1000.0, -1000.0])
print(np.exp(x))
# 输出: [0., 0., 0.]
下溢本身不会立即导致崩溃,但会使分母变成零,或者在后续计算中丢失所有精度。
Log-Sum-Exp技巧:一个优雅的数学救星
核心观察:在Softmax的分子分母中,我们可以乘以任意常数而不改变结果。这个看似平凡的观察,正是解决数值稳定性的钥匙。

如果我们选择$\ln C = -\max(x)$,那么:
$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$这个变换的精妙之处在于:
-
防止溢出:减去最大值后,所有指数的输入$x_i - \max(x) \leq 0$,因此$e^{x_i - \max(x)} \leq 1$。永远不会溢出。
-
保证至少一个非零值:当$x_i = \max(x)$时,$e^0 = 1$。分母至少为1,不会出现除以零的错误。
-
数学等价:结果与原始Softmax完全相同,只是计算过程更稳定。
def stable_softmax(x):
x_shifted = x - np.max(x)
exp_x = np.exp(x_shifted)
return exp_x / np.sum(exp_x)
# 现在可以处理任意大的输入
x = np.array([1000.0, 1000.0, 1000.0])
print(stable_softmax(x))
# 输出: [0.333..., 0.333..., 0.333...]
x = np.array([1.0, 2.0, 3.0])
print(stable_softmax(x))
# 输出: [0.0900..., 0.2447..., 0.6652...]
Log-Sum-Exp函数
在实践中,我们经常需要计算Log-Softmax(Softmax的对数),特别是在交叉熵损失中。直接计算log(softmax(x))会在概率接近零时产生log(0) = -inf的问题。
推导过程:
$$\log(\text{softmax}(x_i)) = \log\left(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\right)$$$$= (x_i - \max(x)) - \log\left(\sum_j e^{x_j - \max(x)}\right)$$这里$\log\sum e^{x_j - \max(x)}$就是著名的Log-Sum-Exp (LSE) 函数。由于$\sum_j e^{x_j - \max(x)} \geq 1$(至少有一个$e^0 = 1$),这个对数永远不会出现$\log(0)$的问题。
def log_softmax(x):
x_max = np.max(x)
return x - x_max - np.log(np.sum(np.exp(x - x_max)))
def cross_entropy_loss(logits, target):
"""数值稳定的交叉熵损失"""
return -log_softmax(logits)[target]
PyTorch的F.log_softmax和F.cross_entropy底层就是这样实现的,它们将Softmax和对数操作融合在一起,既提高了数值稳定性,又减少了计算开销。
混合精度训练:当FP16遇上梯度下溢
启用了FP16混合精度训练后,新的问题出现了。即使Softmax本身是稳定的,梯度可能在其他地方崩溃。

梯度下溢问题
训练过程中,梯度的大小往往随着反向传播逐层减小:
Layer 1: gradient = 0.0001
Layer 2: gradient = 0.00003
Layer 3: gradient = 0.000007
Layer 4: gradient = 0.0000001
FP16能表示的最小正数约为$6 \times 10^{-8}$(归一化数)。第四层的梯度$1 \times 10^{-7}$会被舍入为零。
一项2018年的研究发现,在Multibox SSD检测器网络上,31%的梯度在转换为FP16时变成零,只有5.3%保持非零。训练完全发散。
Loss Scaling:强行放大梯度
解决方案非常直观:如果梯度太小,就在存储到FP16之前先放大。
静态Loss Scaling:
# 手动实现loss scaling
scale = 1024.0
loss = model(input)
scaled_loss = loss * scale
scaled_loss.backward()
# 梯度现在是原来的1024倍,可以安全存储在FP16中
for param in model.parameters():
param.grad = param.grad / scale # 恢复原始大小
optimizer.step()
动态Loss Scaling(PyTorch的GradScaler):
静态缩放的问题是:缩放因子太小则仍有下溢,太大则可能溢出。动态缩放自动调整:
- 从较大的缩放因子开始(如65536)
- 每次反向传播后检查梯度是否有inf/nan
- 如果有:跳过本次更新,缩放因子减半
- 如果连续N次(如2000次)没有问题:缩放因子翻倍
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
for batch in dataloader:
with autocast(): # 自动选择FP16计算
loss = model(batch)
scaler.scale(loss).backward() # 自动缩放
scaler.step(optimizer) # 检查梯度,必要时跳过更新
scaler.update() # 调整缩放因子
optimizer.zero_grad()
Master Weights:累积小更新的智慧
即使解决了梯度下溢,还有一个隐蔽的问题:权重更新的精度损失。
FP16的精度约为3位十进制数字。当权重接近1时,小于0.001的更新会被舍入掉:
weight_fp16 = 1.0
update = 0.0001
weight_fp16 += update # 结果仍然是1.0!
解决方案是维护FP32的"主权重"副本。每次更新时,在FP32精度下累积更新,然后截断到FP16用于前向传播:
# 伪代码展示原理
master_weight_fp32 = model.weight.clone().float()
# 前向传播使用FP16
output = model(input.half())
# 反向传播
loss.backward()
# 在FP32中累积更新
update = learning_rate * model.weight.grad.float()
master_weight_fp32 += update # 精度完整保留
# 复制回FP16
model.weight.data = master_weight_fp32.half()
这就是为什么混合精度训练框架(如PyTorch AMP)在启用时,优化器内部自动维护FP32主权重。
BFloat16:用精度换范围的明智选择
FP16的根本问题是5位指数带来的窄动态范围。Google提出了一个大胆的想法:保留FP32的指数位,只减少尾数。
BFloat16 (Brain Float 16) 格式:
- 1位符号
- 8位指数(与FP32相同)
- 7位尾数
| 格式 | 指数位 | 尾数位 | 动态范围 | 精度 |
|---|---|---|---|---|
| FP32 | 8 | 23 | $\pm 3.4 \times 10^{38}$ | ~7位十进制 |
| FP16 | 5 | 10 | $\pm 65,504$ | ~3位十进制 |
| BF16 | 8 | 7 | $\pm 3.4 \times 10^{38}$ | ~2位十进制 |
BF16拥有与FP32相同的动态范围,这意味着:
- Softmax计算不需要特殊的溢出处理
- 梯度下溢极少发生
- 不需要Loss Scaling
# BF16训练:简单直接
with autocast(dtype=torch.bfloat16):
loss = model(batch)
loss.backward()
optimizer.step() # 没有GradScaler,没有额外技巧
这正是为什么A100/H100等新一代GPU和TPU原生支持BF16的原因。对于大模型训练,BF16是比FP16更好的默认选择——它"开箱即用"。
Flash Attention中的在线Softmax
Flash Attention是一种内存高效的注意力实现,它通过将注意力计算分块处理来减少内存访问。但这带来了一个挑战:如何在分块计算中保持Softmax的数值稳定性?

标准Softmax需要三次遍历:
- 找到最大值
- 计算指数和
- 归一化
而Flash Attention使用在线Softmax算法,将前两次遍历合并:
def online_softmax(x):
"""在线Softmax:一次遍历计算max和sum"""
n = len(x)
# 初始化:第一个元素
m = x[0] # 当前最大值
d = 1.0 # 当前归一化因子
for i in range(1, n):
if x[i] > m:
# 发现新的最大值,需要调整之前的累积和
d = d * np.exp(m - x[i]) + 1.0
m = x[i]
else:
d += np.exp(x[i] - m)
# 最终归一化
return np.exp(x - m) / d
关键洞察:当我们发现新的最大值$m_{new}$时,之前计算的指数和需要重新归一化。所有之前累加的项都要乘以$e^{m_{old} - m_{new}}$来补偿基线的变化。
这个算法使得Flash Attention能够逐块处理注意力矩阵,同时在SRAM中保持数值稳定性,避免了将整个注意力矩阵写入高延迟的HBM。
实践指南:诊断和解决数值问题
症状1:训练开始后Loss变成NaN
检查梯度范数:
total_norm = 0.0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
- 健康值:0.1-10
- 爆炸(>100):添加梯度裁剪
- 消失(<0.001):检查学习率或网络结构
症状2:Loss Scale持续下降
print(f"Current loss scale: {scaler.get_scale()}")
如果缩放因子降到1.0并保持:
- 考虑切换到BF16
- 检查是否存在数值不稳定的层(如极大激活值)
症状3:特定层梯度为零
for name, param in model.named_parameters():
if param.grad is not None:
grad_mean = param.grad.abs().mean().item()
if grad_mean < 1e-10:
print(f"Zero gradient in {name}")
最佳实践清单
优先使用BF16(如果硬件支持):
- A100/H100 GPU
- TPU
- 不需要Loss Scaling
FP16训练的正确方式:
from torch.cuda.amp import GradScaler, autocast
model = Model().cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler()
for batch in dataloader:
with autocast():
output = model(batch)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
避免手动.half()转换:
# 错误:直接转FP16会导致训练崩溃
model = Model().half()
# 正确:使用autocast自动管理精度
with autocast():
output = model(input)
敏感操作保持FP32: PyTorch的autocast自动处理以下操作:
- Layer Normalization
- Batch Normalization
- Softmax
- Cross Entropy Loss
如果手动实现,确保这些操作在FP32精度下执行。
从数学到工程的完整链条
数值稳定性问题从浮点数的物理表示开始,贯穿Softmax的数学定义、梯度计算的精度损失、混合精度训练的权衡,直到Flash Attention的分块算法设计。
每一步都不是孤立的优化技巧,而是对"计算机如何表示数字"这一基础约束的系统性回应:
- Log-Sum-Exp:利用Softmax的数学性质(平移不变性)消除溢出风险
- Loss Scaling:用链式法则将梯度临时放大到可表示范围
- Master Weights:在高精度空间累积微小的权重更新
- BF16:从格式设计层面解决动态范围不足的问题
- Online Softmax:在分块计算中保持数值稳定性的增量算法
理解这些技术背后的数学原理,才能在面对新的训练崩溃时,系统性地定位和解决问题。当你的模型Loss变成NaN时,不再是无助的调试,而是基于浮点数表示限制的理性诊断。
参考资料:
-
Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press. Chapter 4: Numerical Computation.
-
NVIDIA. (2023). Mixed Precision Training Documentation. NVIDIA Deep Learning Performance.
-
Wang, S., & Kanwar, P. (2019). BFloat16: The Secret Behind High Performance on Cloud TPUs. Google Cloud Blog.
-
Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
-
Micikevicius, P., et al. (2018). Mixed Precision Training. ICLR 2018.
-
Dive into Deep Learning (D2L). Chapter 4.5: Numerical Stability and Initialization.
-
PyTorch Documentation. Automatic Mixed Precision (AMP) Package.
-
Kahan, W. (1996). IEEE 754R meeting notes on overflow and underflow.
-
Milakov, M., & Gimelshein, N. (2018). Online Normalizer Calculation for Softmax. arXiv:1805.02867.
-
Moler, C. (2017). Half Precision Arithmetic: 16-bit Floating Point Numbers. MATLAB Blog.