2019年某天凌晨三点,一位研究生的模型训练在运行72小时后突然输出全NaN。损失函数从平稳下降变成无穷大,权重变成NaN,一切归零。这是深度学习开发者的噩梦——数值不稳定。这不是偶然,而是浮点数表示固有的物理限制在特定条件下的必然爆发。
浮点数的物理边界:为什么计算机无法精确表示所有数字
理解数值稳定性问题的第一步是理解计算机如何表示数字。IEEE 754标准定义了浮点数的二进制表示,但这套标准存在根本性的物理限制。
IEEE 754格式解析
一个浮点数由三部分组成:符号位(S)、指数位(E)和尾数位(M)。数值计算公式为:
$$x = (-1)^S \times M \times 2^{E - \text{bias}}$$不同精度的格式有显著差异:
| 格式 | 符号位 | 指数位 | 尾数位 | 正数范围 | 最小正数 |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | $[1.2 \times 10^{-38}, 3.4 \times 10^{38}]$ | $\approx 1.2 \times 10^{-38}$ |
| FP16 | 1 | 5 | 10 | $[6.0 \times 10^{-8}, 65504]$ | $\approx 6.0 \times 10^{-8}$ |
| BF16 | 1 | 8 | 7 | $[1.2 \times 10^{-38}, 3.4 \times 10^{38}]$ | $\approx 1.2 \times 10^{-38}$ |
block-beta
columns 6
block:fp32:6
columns 32
S["S (1)"]
E["E (8)"]
M["M (23)"]
end
block:fp16:6
columns 16
S2["S (1)"]
E2["E (5)"]
M2["M (10)"]
end
block:bf16:6
columns 16
S3["S (1)"]
E3["E (8)"]
M3["M (7)"]
end
上图展示了三种浮点数格式的位分布。这个表格揭示了一个关键洞察:FP16的动态范围远小于FP32。FP16的最大值只有65504,最小正数约为$6 \times 10^{-8}$。这意味着当数值超过65504时会溢出变成inf,低于$6 \times 10^{-8}$时会下溢变成0。
BF16(Brain Float 16)的设计哲学完全不同:它保留了FP32的8位指数,只牺牲尾数精度。这使BF16拥有与FP32相同的动态范围,代价是精度降低。这个设计选择直接解决了FP16训练中的主要数值问题。
溢出与下溢的本质
import numpy as np
# FP32正常工作
print(np.float32([1000, 1000, 1000])) # [1000. 1000. 1000.]
# FP16溢出
print(np.float16([1000, 1000, 1000])) # [inf inf inf]
# FP16下溢
print(np.float16([1e-10, 1e-10, 1e-10])) # [0. 0. 0.]
溢出产生inf(无穷大),下溢产生0。inf会污染所有后续计算(inf - inf = nan),而0可能让梯度消失,导致模型无法学习。
flowchart LR
subgraph 正常范围
A[数值在表示范围内]
end
subgraph 溢出
B[数值 > 最大值]
C[结果 = inf]
end
subgraph 下溢
D[数值 < 最小正数]
E[结果 = 0]
end
A --> F[正常计算]
B --> C
D --> E
C --> G[后续计算污染<br/>inf - inf = NaN]
E --> H[梯度消失<br/>模型无法学习]
Softmax的数值陷阱:指数运算的放大效应
Softmax是深度学习中最常用的激活函数之一,但它也是数值不稳定的主要来源。
标准Softmax的问题
Softmax定义如下:
$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$问题出在指数运算。当$x_i$较大时,$e^{x_i}$可能溢出:
def unstable_softmax(x):
exp_x = np.exp(x)
return exp_x / np.sum(exp_x)
# 正常情况
unstable_softmax([1, 2, 3]) # [0.090, 0.245, 0.665]
# 数值爆炸
unstable_softmax([1000, 2000, 3000]) # [nan, nan, nan]
# 因为 exp(1000) = inf, exp(2000) = inf, exp(3000) = inf
# inf / inf = nan
在FP16精度下,问题更加严重:
# FP16下的不稳定Softmax
x = np.array([6.0, -3, 15], dtype=np.float16)
unstable_softmax(x) # [0., 0., nan]
exp(15)在FP16下溢出,导致分母为inf,结果为nan。
Log-Sum-Exp技巧
解决方案是一个巧妙的数学恒等式。注意到:
$$\frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - c} \cdot e^c}{\sum_j e^{x_j - c} \cdot e^c} = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}$$其中$c$是任意常数。如果我们选择$c = \max(x)$,则$x_i - c \leq 0$,所以$e^{x_i - c} \leq 1$。这完全避免了溢出:
def stable_softmax(x):
x_max = np.max(x)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x)
# 现在可以处理任意大的输入
stable_softmax([1000, 2000, 3000]) # [0., 0., 1.]
stable_softmax(np.array([6.0, -3, 15], dtype=np.float16)) # [1.23e-04, 0., 1.]
flowchart TB
subgraph 不稳定Softmax
A[输入: 1000, 2000, 3000] --> B[exp计算]
B --> C["exp(1000)=inf<br/>exp(2000)=inf<br/>exp(3000)=inf"]
C --> D["归一化: inf/inf"]
D --> E[结果: NaN]
end
subgraph 稳定Softmax
F[输入: 1000, 2000, 3000] --> G[减去最大值: -2000, 0, 1000]
G --> H["exp计算"]
H --> I["exp(-2000)≈0<br/>exp(0)=1<br/>exp(1000)=有限值"]
I --> J[归一化]
J --> K[结果: 0, 0, 1]
end
这个技巧的本质是平移不变性:在分子分母同时减去最大值,不改变概率分布,但将所有指数运算限制在安全范围内。
Log-Sum-Exp函数的稳定实现
在计算对数概率时,我们需要稳定地计算$\log(\sum_i e^{x_i})$。直接计算会遭遇相同的溢出问题:
$$\log\sum_i e^{x_i} = \log\sum_i e^{x_i - c + c} = c + \log\sum_i e^{x_i - c}$$其中$c = \max(x)$:
def logsumexp(x):
c = np.max(x)
return c + np.log(np.sum(np.exp(x - c)))
# 可以处理任意范围的输入
logsumexp([1000, 1000, 1000]) # 1001.099 (不是inf)
logsumexp([-1000, -1000, -1000]) # -998.901 (不是-inf)
Log-Softmax的稳定实现
交叉熵损失通常涉及$\log(\text{softmax}(x))$。直接组合log和softmax会遭遇log(0)问题:
$$\log(\text{softmax}(x)_i) = \log\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} = x_i - \max(x) - \log\sum_j e^{x_j - \max(x)}$$def log_softmax(x):
x_max = np.max(x)
return x - x_max - np.log(np.sum(np.exp(x - x_max)))
def cross_entropy_loss(logits, target):
return -log_softmax(logits)[target]
# 即使target对应的softmax概率接近0,也能得到有限值
cross_entropy_loss([-1000, 1000], 0) # 2000.0 (不是inf)
混合精度训练:在速度与精度之间走钢丝
现代GPU对FP16运算有专门的硬件加速(Tensor Core),吞吐量可达FP32的2-8倍。但FP16的窄动态范围会破坏训练稳定性。混合精度训练试图在两者之间取得平衡。
FP16训练的三大问题
NVIDIA 2017年的开创性论文《Mixed Precision Training》系统分析了FP16训练的三个核心问题:
问题一:权重更新下溢
权重更新通常是$\Delta w = \eta \cdot \nabla w$,其中$\eta$(学习率)通常在$10^{-3}$到$10^{-5}$量级,而梯度可能更小。FP16的最小正数约为$6 \times 10^{-8}$,当$\eta \cdot \nabla w < 6 \times 10^{-8}$时,更新量变成0,模型无法学习。
实验数据表明,约5%的权重梯度值指数小于-24(即值小于$10^{-24}$),这些值在FP16中会变成零。
问题二:权重更新精度损失
即使$\Delta w$能被FP16表示,当$|w| \gg |\Delta w|$时,$w + \Delta w$可能因为浮点数对齐而丢失$\Delta w$。具体来说,当$|w| / |\Delta w| > 2048$时,右移超过11位(FP16有10位尾数),$\Delta w$会完全消失。
问题三:梯度下溢
深度网络的反向传播中,梯度可能逐层衰减。FP16的限制梯度会在传播过程中丢失,尤其是来自Softmax层的梯度。
flowchart TB
subgraph 问题链
A[FP16训练] --> B{三大问题}
B --> C[权重更新下溢<br/>η·∇w < 6×10⁻⁸]
B --> D[精度损失<br/>w + Δw 丢失Δw]
B --> E[梯度下溢<br/>反向传播衰减]
C --> F[模型无法学习]
D --> G[训练停滞]
E --> H[梯度消失]
end
三大解决方案
方案一:FP32主权重副本
# 维护FP32主权重
master_weights = model.parameters().float().clone() # FP32副本
for input, target in data:
# 前向传播使用FP16
fp16_weights = master_weights.half()
output = model(input, fp16_weights)
loss = criterion(output, target)
# 反向传播
loss.backward()
# 更新FP32主权重
for p, master_p in zip(model.parameters(), master_weights):
master_p += lr * p.grad.float() # 梯度转FP32后更新
# 同步回FP16
p.data = master_p.half()
方案二:损失缩放
梯度下溢可以通过放大损失来缓解:
scale_factor = 8192 # 放大因子
# 前向传播
output = model(input)
loss = criterion(output, target)
# 缩放损失
scaled_loss = loss * scale_factor
# 反向传播(梯度自动放大scale_factor倍)
scaled_loss.backward()
# 反缩放梯度
for p in model.parameters():
p.grad = p.grad / scale_factor
# 更新权重
optimizer.step()
关键洞察:缩放因子可以很大(如32768),只要不导致梯度溢出。溢出可以被检测(出现inf/nan),检测到时可以跳过本次更新或减小缩放因子。
方案三:FP32累积
矩阵乘法的部分积应该用FP32累积:
$$C_{ij} = \sum_k A_{ik} \cdot B_{kj}$$每个乘积$A_{ik} \cdot B_{kj}$可能在FP16中,但求和应该用FP32,最后再转回FP16存储。现代GPU的Tensor Core原生支持这种操作。
flowchart LR
subgraph 混合精度训练流水线
A[FP32主权重] -->|转换| B[FP16计算权重]
B --> C[FP16前向传播]
C --> D[计算损失]
D --> E[损失缩放]
E --> F[FP16反向传播]
F --> G[梯度反缩放]
G --> H[FP32权重更新]
H --> A
end
动态损失缩放
静态缩放因子需要手动调参。动态缩放自动调整:
class DynamicGradScaler:
def __init__(self, init_scale=2**15, growth_factor=2.0,
backoff_factor=0.5, growth_interval=2000):
self.scale = init_scale
self.growth_factor = growth_factor
self.backoff_factor = backoff_factor
self.growth_interval = growth_interval
self._growth_tracker = 0
def scale_loss(self, loss):
return loss * self.scale
def unscale_grads(self, optimizer):
for p in optimizer.param_groups[0]['params']:
if p.grad is not None:
p.grad = p.grad / self.scale
def update(self, found_inf):
if found_inf:
# 检测到溢出,减小缩放因子
self.scale *= self.backoff_factor
self._growth_tracker = 0
else:
# 正常更新,尝试增大缩放因子
self._growth_tracker += 1
if self._growth_tracker >= self.growth_interval:
self.scale *= self.growth_factor
self._growth_tracker = 0
BF16:免调参的解决方案
BF16的设计从根本上解决了FP16的问题:
| 特性 | FP16 | BF16 |
|---|---|---|
| 指数位 | 5位 | 8位 |
| 尾数位 | 10位 | 7位 |
| 动态范围 | 窄(与FP32差异大) | 宽(与FP32相同) |
| 精度 | 高 | 低 |
| 需要损失缩放 | 通常需要 | 通常不需要 |
BF16保留了FP32的动态范围,这意味着:
- 梯度下溢风险大幅降低
- 权重更新下溢风险大幅降低
- 在大多数情况下不需要损失缩放
代价是精度降低(只有7位尾数),但深度学习的数值研究表明,这种精度损失对模型质量影响有限。
graph TB
subgraph FP16训练
A1[需要损失缩放] --> B1[需要FP32主权重]
B1 --> C1[动态缩放调参]
C1 --> D1[复杂度高]
end
subgraph BF16训练
A2[通常不需要损失缩放] --> B2[FP32主权重推荐但非必需]
B2 --> C2[无需缩放调参]
C2 --> D2[复杂度低]
end
# PyTorch中使用BF16
model = model.to(torch.bfloat16)
# BF16训练通常不需要GradScaler
# 但仍需要FP32主权重
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
optimizer.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = model(batch)
loss = criterion(output)
loss.backward()
optimizer.step()
Flash Attention:分块计算的数值挑战
Flash Attention通过分块计算减少显存访问,但分块Softmax带来了新的数值稳定性挑战。
分块Softmax的困难
标准Softmax需要全局归一化,但Flash Attention将输入分成小块处理。如何保证分块计算的数值稳定性?
关键洞察是在线Softmax算法。假设我们要计算:
$$\text{softmax}(x) = \frac{e^{x}}{\sum_j e^{x_j}}$$将$x$分成两块$x^{(1)}$和$x^{(2)}$:
$$\sum_j e^{x_j} = \sum_{j \in \text{block}_1} e^{x_j} + \sum_{j \in \text{block}_2} e^{x_j}$$但这在数值上不稳定。正确的方法是跟踪每块的最大值:
def online_softmax(x_blocks):
"""在线Softmax算法"""
global_max = -float('inf')
global_sum = 0.0
for block in x_blocks:
block_max = np.max(block)
if global_max == -float('inf'):
global_max = block_max
global_sum = np.sum(np.exp(block - block_max))
else:
# 关键:需要重新缩放之前的累积和
if block_max > global_max:
scale = np.exp(global_max - block_max)
global_sum = global_sum * scale + np.sum(np.exp(block - block_max))
global_max = block_max
else:
scale = np.exp(block_max - global_max)
global_sum = global_sum + np.sum(np.exp(block - global_max))
return global_max + np.log(global_sum)
flowchart LR
subgraph 在线Softmax
A[块1] --> C[当前最大值 m1<br/>累积和 exp x-m1]
B[块2] --> D[新最大值 m2<br/>新累积和]
C --> E{m2 > m1?}
D --> E
E -->|是| F[重新缩放旧累积和<br/>×exp m1-m2]
E -->|否| G[缩放新累积和<br/>×exp m2-m1]
F --> H[合并累积和]
G --> H
end
Flash Attention的稳定实现
Flash Attention的核心是分块计算注意力,同时保持数值稳定性:
def flash_attention(Q, K, V, block_size=256):
"""简化的Flash Attention实现"""
seq_len = Q.shape[0]
output = np.zeros_like(Q)
for i in range(0, seq_len, block_size):
q_block = Q[i:i+block_size]
block_max = np.full(block_size, -float('inf'))
block_sum = np.zeros(block_size)
block_output = np.zeros_like(q_block)
for j in range(0, seq_len, block_size):
k_block = K[j:j+block_size]
v_block = V[j:j+block_size]
# 计算当前块的注意力分数
scores = np.dot(q_block, k_block.T)
# 在线更新最大值和归一化因子
new_max = np.maximum(block_max, np.max(scores, axis=1))
scale = np.exp(block_max - new_max)
scale_scores = np.exp(scores - new_max[:, None])
block_sum = block_sum * scale + np.sum(scale_scores, axis=1)
block_output = block_output * scale[:, None] + np.dot(scale_scores, v_block)
block_max = new_max
output[i:i+block_size] = block_output / block_sum[:, None]
return output
关键点:
- 每块计算时减去当前最大值防止溢出
- 跨块合并时需要重新缩放之前的结果
- 最终归一化在所有块处理完后进行
梯度问题:深度网络的数学诅咒
梯度消失和爆炸是深度学习的经典问题,其根源在于反向传播的链式法则。
数学分析
考虑一个$L$层的深度网络,每层的激活函数为$\phi$。第$l$层的梯度为:
$$\frac{\partial L}{\partial W_l} = \frac{\partial L}{\partial a_L} \prod_{k=l+1}^{L} \frac{\partial a_k}{\partial a_{k-1}} \cdot \frac{\partial a_l}{\partial W_l}$$关键在于连乘项$\prod_{k=l+1}^{L} \frac{\partial a_k}{\partial a_{k-1}}$。假设每层的导数近似为$\gamma$,则$L-l$层后的梯度大约为$\gamma^{L-l}$。
- 当$|\gamma| < 1$时,$\gamma^{L-l} \to 0$(梯度消失)
- 当$|\gamma| > 1$时,$\gamma^{L-l} \to \infty$(梯度爆炸)
graph LR
subgraph 反向传播梯度流
A[输出层<br/>梯度=1.0] --> B[第L-1层<br/>梯度×σ']
B --> C[第L-2层<br/>梯度×σ']
C --> D[...]
D --> E[第1层<br/>梯度×σ'^L]
end
subgraph 梯度消失
F["σ'=0.25<br/>10层后: 0.25^10≈10^-6"]
end
subgraph 梯度爆炸
G["σ'>1<br/>可能发散"]
end
Sigmoid的问题
Sigmoid函数$\sigma(x) = \frac{1}{1+e^{-x}}$的导数为:
$$\sigma'(x) = \sigma(x)(1-\sigma(x)) \leq 0.25$$最大导数只有0.25。对于10层网络,梯度可能衰减到$0.25^{10} \approx 10^{-6}$,在FP16下直接变成0。
ReLU的改进
ReLU函数$f(x) = \max(0, x)$的导数为:
$$f'(x) = \begin{cases} 1 & x > 0 \\ 0 & x \leq 0 \end{cases}$$当$x > 0$时,导数为1,不会衰减。这解决了梯度消失问题(在激活区域),但可能引入梯度爆炸。
梯度裁剪
梯度裁剪是解决梯度爆炸的常用技术:
def clip_grad_norm_(parameters, max_norm, norm_type=2):
"""梯度范数裁剪"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == float('inf'):
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
return total_norm
两种常见策略:
- 按值裁剪:将梯度限制在$[-c, c]$
- 按范数裁剪:当梯度范数超过阈值时,缩放梯度使其范数等于阈值
Xavier与He初始化
合理的权重初始化可以从源头控制梯度问题。
Xavier初始化(适用于tanh/sigmoid):
$$W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} + n_{out}}}\right)$$设计目标:使每层的输入方差等于输出方差,保持信号强度。
He初始化(适用于ReLU):
$$W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in}}}\right)$$考虑到ReLU会"杀死"一半的神经元,需要额外的$\sqrt{2}$补偿。
Loss Spike:Transformer训练的幽灵
大模型训练中常见一种现象:损失函数突然跳升几个数量级,然后可能恢复,也可能发散。这就是Loss Spike。
Loss Spike的成因
研究表明,Loss Spike通常由以下因素触发:
- 数据异常:训练数据中的异常样本可能产生极大的梯度
- 学习率过高:过大的学习率使优化器跳过局部最优,进入损失地形的高曲率区域
- 注意力分布坍塌:Softmax输出的极端分布(如[0, 0, …, 1, …0])导致梯度不稳定
- BatchNorm统计异常:小批量中的极端样本扭曲running statistics
graph TB
subgraph Loss Spike触发因素
A[数据异常样本] --> E[极大梯度]
B[学习率过高] --> F[跳过最优区]
C[注意力分布坍塌] --> G[梯度不稳定]
D[BatchNorm统计异常] --> H[归一化失效]
end
E --> I[梯度范数爆炸]
F --> I
G --> I
H --> I
I --> J[Loss Spike]
检测与缓解
class LossSpikeHandler:
def __init__(self, threshold=5.0, patience=3):
self.threshold = threshold # 相对于滑动平均的倍数
self.patience = patience
self.loss_history = []
self.spike_count = 0
def check_spike(self, loss):
self.loss_history.append(loss)
if len(self.loss_history) < 10:
return False
# 计算滑动平均和标准差
recent = self.loss_history[-10:]
mean = np.mean(recent)
std = np.std(recent)
# 检测异常
is_spike = loss > mean + self.threshold * std
if is_spike:
self.spike_count += 1
if self.spike_count >= self.patience:
return True # 需要干预
else:
self.spike_count = 0
return False
def intervene(self, optimizer, original_lr):
"""干预策略"""
# 降低学习率
for param_group in optimizer.param_groups:
param_group['lr'] = original_lr * 0.1
ZClip:自适应梯度裁剪
2025年提出的ZClip算法提供了一种自适应的解决方案:
class ZClip:
"""自适应梯度裁剪,专门针对Loss Spike"""
def __init__(self, z_threshold=3.0, min_grad_norm=1e-6):
self.z_threshold = z_threshold
self.grad_norm_history = []
self.min_grad_norm = min_grad_norm
def clip(self, parameters):
# 计算当前梯度范数
total_norm = 0
for p in parameters:
if p.grad is not None:
total_norm += p.grad.data.norm() ** 2
total_norm = total_norm ** 0.5
self.grad_norm_history.append(total_norm)
if len(self.grad_norm_history) < 10:
return total_norm
# 计算Z分数
recent = self.grad_norm_history[-10:]
mean = np.mean(recent)
std = np.std(recent)
z_score = (total_norm - mean) / (std + self.min_grad_norm)
# 如果Z分数超过阈值,裁剪
if z_score > self.z_threshold:
clip_coef = mean / (total_norm + self.min_grad_norm)
for p in parameters:
if p.grad is not None:
p.grad.data.mul_(clip_coef)
return mean
return total_norm
PyTorch中的数值稳定性最佳实践
PyTorch提供了多种内置工具来处理数值稳定性问题。
自动混合精度(AMP)
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler(init_scale=2**15, growth_factor=2.0,
backoff_factor=0.5, growth_interval=2000)
for epoch in range(num_epochs):
for batch, target in dataloader:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(batch)
loss = criterion(output, target)
# 缩放损失并反向传播
scaler.scale(loss).backward()
# 反缩放梯度并裁剪
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 更新权重
scaler.step(optimizer)
scaler.update()
数值稳定的层
PyTorch提供了数值稳定的组合层:
import torch.nn.functional as F
# 不要分开使用
logits = model(x)
probs = F.softmax(logits, dim=-1)
log_probs = torch.log(probs) # 可能出现log(0)
# 使用组合版本
log_probs = F.log_softmax(logits, dim=-1) # 数值稳定
# 交叉熵损失也内置了稳定的组合
loss = F.cross_entropy(logits, targets) # 等价于 log_softmax + nll_loss
LayerNorm中的Epsilon
LayerNorm需要计算标准差,可能除以零。PyTorch内置了epsilon保护:
# LayerNorm实现
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + eps) # eps防止除零
return weight * x_norm + bias
默认eps=1e-5在FP32下足够,但在FP16训练中可能需要增大:
# FP16训练可能需要更大的epsilon
layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) # 或更大
调试数值问题:从NaN到根源
当模型输出NaN时,系统性的调试流程至关重要。
检测NaN/Inf
def check_tensor_health(tensor, name="tensor"):
"""检查张量的数值健康状况"""
has_nan = torch.isnan(tensor).any()
has_inf = torch.isinf(tensor).any()
if has_nan or has_inf:
print(f"[ALERT] {name}: NaN={has_nan.item()}, Inf={has_inf.item()}")
print(f" Shape: {tensor.shape}")
print(f" Min: {tensor[~torch.isnan(tensor) & ~torch.isinf(tensor)].min() if (~torch.isnan(tensor) & ~torch.isinf(tensor)).any() else 'N/A'}")
print(f" Max: {tensor[~torch.isnan(tensor) & ~torch.isinf(tensor)].max() if (~torch.isnan(tensor) & ~torch.isinf(tensor)).any() else 'N/A'}")
return False
return True
# 使用钩子监控每层输出
def register_health_checks(model):
for name, module in model.named_modules():
module.register_forward_hook(
lambda m, i, o, n=name: check_tensor_health(o, n) if o is not None else None
)
flowchart TB
subgraph NaN调试流程
A[检测到NaN] --> B{检查输入数据}
B -->|有NaN| C[修复数据预处理]
B -->|无NaN| D{检查第一层输出}
D -->|有NaN| E[检查初始化]
D -->|无NaN| F{逐层检查}
F -->|定位到某层| G{检查该层}
G --> H[学习率过高?]
G --> I[损失缩放?]
G --> J[归一化eps?]
G --> K[除法保护?]
end
常见原因清单
| 症状 | 可能原因 | 解决方案 |
|---|---|---|
| 训练初期就出现NaN | 学习率过高 | 降低学习率1-2个数量级 |
| 混合精度训练出现NaN | 损失缩放不当 | 使用动态缩放或BF16 |
| 特定批次触发NaN | 数据问题 | 检查数据预处理,添加NaN过滤 |
| 深层网络出现NaN | LayerNorm epsilon过小 | 增大eps参数 |
| Softmax后出现NaN | 除法运算无保护 | 添加epsilon保护 |
| 第一轮迭代就出现NaN | 权重初始化不当 | 使用Xavier或He初始化 |
FP8:数值稳定性的新前沿
FP8(8位浮点数)正在成为LLM训练的新标准,但带来了更严峻的数值挑战。
FP8格式
FP8有两种主流格式:
| 格式 | 符号位 | 指数位 | 尾数位 | 范围 |
|---|---|---|---|---|
| E4M3 | 1 | 4 | 3 | $[2^{-6}, 448]$ |
| E5M2 | 1 | 5 | 2 | $[2^{-14}, 57344]$ |
E4M3精度更高但范围更窄,适合存储激活和权重;E5M2范围更广但精度更低,适合存储梯度。
block-beta
columns 4
block:e4m3:4
columns 8
S1["S"]
E1["E"]
E2["E"]
E3["E"]
E4["E"]
M1["M"]
M2["M"]
M3["M"]
end
block:e5m2:4
columns 8
S2["S"]
E5["E"]
E6["E"]
E7["E"]
E8["E"]
E9["E"]
M4["M"]
M5["M"]
end
FP8训练的核心挑战
-
异常特征(Outlier Features):Transformer训练中会出现激活值比其他值大100-1000倍的异常特征,FP8无法同时表示这种动态范围
-
量化误差累积:每次FP8转换都会损失精度,多层累积后可能导致发散
-
梯度表示:梯度通常有更大的动态范围和更多的小值
解决方案:动态缩放
class FP8DynamicScaler:
"""FP8动态缩放"""
def __init__(self, margin=3):
self.margin = margin # 偏移量,避免溢出
self.scale = 1.0
def compute_scale(self, tensor, fp8_max=448.0):
"""计算最优缩放因子"""
tensor_max = tensor.abs().max()
if tensor_max == 0:
return self.scale
# 计算缩放因子使最大值接近FP8最大值
target_max = fp8_max / (2 ** self.margin)
new_scale = target_max / tensor_max
# 平滑更新
self.scale = 0.9 * self.scale + 0.1 * new_scale
return self.scale
def quantize(self, tensor):
scale = self.compute_scale(tensor)
return (tensor * scale).to(torch.int8).to(torch.float32) / scale
实践建议:构建数值稳定的训练流水线
综合以上分析,以下是构建数值稳定训练流水线的检查清单:
flowchart TB
subgraph 数据阶段
A1[检查极端值/NaN] --> A2[数据归一化]
A2 --> A3[数据增强验证]
end
subgraph 模型阶段
B1[权重初始化<br/>He/Xavier] --> B2[LayerNorm eps]
B2 --> B3[使用稳定API<br/>log_softmax]
end
subgraph 训练阶段
C1[混合精度<br/>优先BF16] --> C2[梯度裁剪]
C2 --> C3[监控梯度/损失]
end
subgraph 调试阶段
D1[健康检查钩子] --> D2[梯度统计日志]
D2 --> D3[确定性模式复现]
end
A3 --> B1
B3 --> C1
C3 --> D1
数据阶段
- 检查输入数据是否包含极端值或NaN
- 确保数据归一化到合理范围(如[-1, 1]或[0, 1])
- 添加数据增强时注意不产生数值异常
模型阶段
- 使用He初始化(ReLU)或Xavier初始化(tanh/sigmoid)
- LayerNorm/BatchNorm使用适当的epsilon
- 避免手动实现softmax,使用框架的log_softmax
训练阶段
- 使用混合精度训练(优先BF16)
- 启用梯度裁剪(norm clipping)
- 监控梯度范数和损失曲线
调试阶段
- 注册张量健康检查钩子
- 记录每层的梯度统计
- 使用确定性模式复现问题
数值稳定性是深度学习系统工程的基石。理解浮点数的物理限制、掌握Log-Sum-Exp等数值技巧、合理配置混合精度训练,这些都是现代深度学习工程师的必备技能。当代价是72小时的训练成果化为NaN时,预防永远比修复更重要。
参考资料
- Micikevicius, P., et al. “Mixed Precision Training.” ICLR 2018.
- Zhao, R., et al. “Reducing Underflow in Mixed Precision Training by Gradient Scaling.” IJCAI 2020.
- NVIDIA. “Training With Mixed Precision.” NVIDIA Developer Documentation.
- Kahan, W. “IEEE Standard 754 for Binary Floating-Point Arithmetic.” 1997.
- Blanchard, P., et al. “Mixed Precision Block Fused Multiply-Add: Error Analysis.” 2019.
- Dao, T., et al. “FlashAttention: Fast and Memory-Efficient Exact Attention.” NeurIPS 2022.
- Peng, B., et al. “FP8-LM: Training FP8 Large Language Models.” arXiv 2023.
- “Numerical Stability in Flash Attention.” jarbus.net, 2023.
- “Numerically Stable Softmax and Cross Entropy.” Jay Mody, 2022.
- Gundersen, G. “The Log-Sum-Exp Trick.” 2020.
- “Understanding the FP64, FP32, FP16, BFLOAT16, TF32, FP8 Formats.” jeffreytse.net, 2024.
- “BF16 vs FP16: A Comparison of Performance and Efficiency.” Beam Blog, 2025.
- “An Empirical Study on Numerical Bugs in Deep Learning Programs.” ACM Digital Library.
- “Loss spikes in training: causes, detection, and mitigations.” Medium, 2026.
- “ZClip: Adaptive Spike Mitigation for LLM Pre-Training.” arXiv 2025.
- “Enhancing LLM Pretraining Stability via Adaptive Gradient Clipping.” arXiv 2026.