2019年某天凌晨三点,一位研究生的模型训练在运行72小时后突然输出全NaN。损失函数从平稳下降变成无穷大,权重变成NaN,一切归零。这是深度学习开发者的噩梦——数值不稳定。这不是偶然,而是浮点数表示固有的物理限制在特定条件下的必然爆发。

浮点数的物理边界:为什么计算机无法精确表示所有数字

理解数值稳定性问题的第一步是理解计算机如何表示数字。IEEE 754标准定义了浮点数的二进制表示,但这套标准存在根本性的物理限制。

IEEE 754格式解析

一个浮点数由三部分组成:符号位(S)、指数位(E)和尾数位(M)。数值计算公式为:

$$x = (-1)^S \times M \times 2^{E - \text{bias}}$$

不同精度的格式有显著差异:

格式 符号位 指数位 尾数位 正数范围 最小正数
FP32 1 8 23 $[1.2 \times 10^{-38}, 3.4 \times 10^{38}]$ $\approx 1.2 \times 10^{-38}$
FP16 1 5 10 $[6.0 \times 10^{-8}, 65504]$ $\approx 6.0 \times 10^{-8}$
BF16 1 8 7 $[1.2 \times 10^{-38}, 3.4 \times 10^{38}]$ $\approx 1.2 \times 10^{-38}$
block-beta
    columns 6
    block:fp32:6
        columns 32
        S["S (1)"]
        E["E (8)"]
        M["M (23)"]
    end
    block:fp16:6
        columns 16
        S2["S (1)"]
        E2["E (5)"]
        M2["M (10)"]
    end
    block:bf16:6
        columns 16
        S3["S (1)"]
        E3["E (8)"]
        M3["M (7)"]
    end

上图展示了三种浮点数格式的位分布。这个表格揭示了一个关键洞察:FP16的动态范围远小于FP32。FP16的最大值只有65504,最小正数约为$6 \times 10^{-8}$。这意味着当数值超过65504时会溢出变成inf,低于$6 \times 10^{-8}$时会下溢变成0。

BF16(Brain Float 16)的设计哲学完全不同:它保留了FP32的8位指数,只牺牲尾数精度。这使BF16拥有与FP32相同的动态范围,代价是精度降低。这个设计选择直接解决了FP16训练中的主要数值问题。

溢出与下溢的本质

import numpy as np

# FP32正常工作
print(np.float32([1000, 1000, 1000]))  # [1000. 1000. 1000.]

# FP16溢出
print(np.float16([1000, 1000, 1000]))  # [inf inf inf]

# FP16下溢
print(np.float16([1e-10, 1e-10, 1e-10]))  # [0. 0. 0.]

溢出产生inf(无穷大),下溢产生0。inf会污染所有后续计算(inf - inf = nan),而0可能让梯度消失,导致模型无法学习。

flowchart LR
    subgraph 正常范围
        A[数值在表示范围内]
    end
    subgraph 溢出
        B[数值 > 最大值]
        C[结果 = inf]
    end
    subgraph 下溢
        D[数值 < 最小正数]
        E[结果 = 0]
    end
    A --> F[正常计算]
    B --> C
    D --> E
    C --> G[后续计算污染<br/>inf - inf = NaN]
    E --> H[梯度消失<br/>模型无法学习]

Softmax的数值陷阱:指数运算的放大效应

Softmax是深度学习中最常用的激活函数之一,但它也是数值不稳定的主要来源。

标准Softmax的问题

Softmax定义如下:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

问题出在指数运算。当$x_i$较大时,$e^{x_i}$可能溢出:

def unstable_softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

# 正常情况
unstable_softmax([1, 2, 3])  # [0.090, 0.245, 0.665]

# 数值爆炸
unstable_softmax([1000, 2000, 3000])  # [nan, nan, nan]
# 因为 exp(1000) = inf, exp(2000) = inf, exp(3000) = inf
# inf / inf = nan

在FP16精度下,问题更加严重:

# FP16下的不稳定Softmax
x = np.array([6.0, -3, 15], dtype=np.float16)
unstable_softmax(x)  # [0., 0., nan]

exp(15)在FP16下溢出,导致分母为inf,结果为nan。

Log-Sum-Exp技巧

解决方案是一个巧妙的数学恒等式。注意到:

$$\frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - c} \cdot e^c}{\sum_j e^{x_j - c} \cdot e^c} = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}$$

其中$c$是任意常数。如果我们选择$c = \max(x)$,则$x_i - c \leq 0$,所以$e^{x_i - c} \leq 1$。这完全避免了溢出:

def stable_softmax(x):
    x_max = np.max(x)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x)

# 现在可以处理任意大的输入
stable_softmax([1000, 2000, 3000])  # [0., 0., 1.]
stable_softmax(np.array([6.0, -3, 15], dtype=np.float16))  # [1.23e-04, 0., 1.]
flowchart TB
    subgraph 不稳定Softmax
        A[输入: 1000, 2000, 3000] --> B[exp计算]
        B --> C["exp(1000)=inf<br/>exp(2000)=inf<br/>exp(3000)=inf"]
        C --> D["归一化: inf/inf"]
        D --> E[结果: NaN]
    end
    subgraph 稳定Softmax
        F[输入: 1000, 2000, 3000] --> G[减去最大值: -2000, 0, 1000]
        G --> H["exp计算"]
        H --> I["exp(-2000)≈0<br/>exp(0)=1<br/>exp(1000)=有限值"]
        I --> J[归一化]
        J --> K[结果: 0, 0, 1]
    end

这个技巧的本质是平移不变性:在分子分母同时减去最大值,不改变概率分布,但将所有指数运算限制在安全范围内。

Log-Sum-Exp函数的稳定实现

在计算对数概率时,我们需要稳定地计算$\log(\sum_i e^{x_i})$。直接计算会遭遇相同的溢出问题:

$$\log\sum_i e^{x_i} = \log\sum_i e^{x_i - c + c} = c + \log\sum_i e^{x_i - c}$$

其中$c = \max(x)$:

def logsumexp(x):
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))

# 可以处理任意范围的输入
logsumexp([1000, 1000, 1000])  # 1001.099 (不是inf)
logsumexp([-1000, -1000, -1000])  # -998.901 (不是-inf)

Log-Softmax的稳定实现

交叉熵损失通常涉及$\log(\text{softmax}(x))$。直接组合log和softmax会遭遇log(0)问题:

$$\log(\text{softmax}(x)_i) = \log\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} = x_i - \max(x) - \log\sum_j e^{x_j - \max(x)}$$
def log_softmax(x):
    x_max = np.max(x)
    return x - x_max - np.log(np.sum(np.exp(x - x_max)))

def cross_entropy_loss(logits, target):
    return -log_softmax(logits)[target]

# 即使target对应的softmax概率接近0,也能得到有限值
cross_entropy_loss([-1000, 1000], 0)  # 2000.0 (不是inf)

混合精度训练:在速度与精度之间走钢丝

现代GPU对FP16运算有专门的硬件加速(Tensor Core),吞吐量可达FP32的2-8倍。但FP16的窄动态范围会破坏训练稳定性。混合精度训练试图在两者之间取得平衡。

FP16训练的三大问题

NVIDIA 2017年的开创性论文《Mixed Precision Training》系统分析了FP16训练的三个核心问题:

问题一:权重更新下溢

权重更新通常是$\Delta w = \eta \cdot \nabla w$,其中$\eta$(学习率)通常在$10^{-3}$到$10^{-5}$量级,而梯度可能更小。FP16的最小正数约为$6 \times 10^{-8}$,当$\eta \cdot \nabla w < 6 \times 10^{-8}$时,更新量变成0,模型无法学习。

实验数据表明,约5%的权重梯度值指数小于-24(即值小于$10^{-24}$),这些值在FP16中会变成零。

问题二:权重更新精度损失

即使$\Delta w$能被FP16表示,当$|w| \gg |\Delta w|$时,$w + \Delta w$可能因为浮点数对齐而丢失$\Delta w$。具体来说,当$|w| / |\Delta w| > 2048$时,右移超过11位(FP16有10位尾数),$\Delta w$会完全消失。

问题三:梯度下溢

深度网络的反向传播中,梯度可能逐层衰减。FP16的限制梯度会在传播过程中丢失,尤其是来自Softmax层的梯度。

flowchart TB
    subgraph 问题链
        A[FP16训练] --> B{三大问题}
        B --> C[权重更新下溢<br/>η·∇w < 6×10⁻⁸]
        B --> D[精度损失<br/>w + Δw 丢失Δw]
        B --> E[梯度下溢<br/>反向传播衰减]
        C --> F[模型无法学习]
        D --> G[训练停滞]
        E --> H[梯度消失]
    end

三大解决方案

方案一:FP32主权重副本

# 维护FP32主权重
master_weights = model.parameters().float().clone()  # FP32副本

for input, target in data:
    # 前向传播使用FP16
    fp16_weights = master_weights.half()
    output = model(input, fp16_weights)
    loss = criterion(output, target)
    
    # 反向传播
    loss.backward()
    
    # 更新FP32主权重
    for p, master_p in zip(model.parameters(), master_weights):
        master_p += lr * p.grad.float()  # 梯度转FP32后更新
    # 同步回FP16
    p.data = master_p.half()

方案二:损失缩放

梯度下溢可以通过放大损失来缓解:

scale_factor = 8192  # 放大因子

# 前向传播
output = model(input)
loss = criterion(output, target)

# 缩放损失
scaled_loss = loss * scale_factor

# 反向传播(梯度自动放大scale_factor倍)
scaled_loss.backward()

# 反缩放梯度
for p in model.parameters():
    p.grad = p.grad / scale_factor

# 更新权重
optimizer.step()

关键洞察:缩放因子可以很大(如32768),只要不导致梯度溢出。溢出可以被检测(出现inf/nan),检测到时可以跳过本次更新或减小缩放因子。

方案三:FP32累积

矩阵乘法的部分积应该用FP32累积:

$$C_{ij} = \sum_k A_{ik} \cdot B_{kj}$$

每个乘积$A_{ik} \cdot B_{kj}$可能在FP16中,但求和应该用FP32,最后再转回FP16存储。现代GPU的Tensor Core原生支持这种操作。

flowchart LR
    subgraph 混合精度训练流水线
        A[FP32主权重] -->|转换| B[FP16计算权重]
        B --> C[FP16前向传播]
        C --> D[计算损失]
        D --> E[损失缩放]
        E --> F[FP16反向传播]
        F --> G[梯度反缩放]
        G --> H[FP32权重更新]
        H --> A
    end

动态损失缩放

静态缩放因子需要手动调参。动态缩放自动调整:

class DynamicGradScaler:
    def __init__(self, init_scale=2**15, growth_factor=2.0, 
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0
    
    def scale_loss(self, loss):
        return loss * self.scale
    
    def unscale_grads(self, optimizer):
        for p in optimizer.param_groups[0]['params']:
            if p.grad is not None:
                p.grad = p.grad / self.scale
    
    def update(self, found_inf):
        if found_inf:
            # 检测到溢出,减小缩放因子
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            # 正常更新,尝试增大缩放因子
            self._growth_tracker += 1
            if self._growth_tracker >= self.growth_interval:
                self.scale *= self.growth_factor
                self._growth_tracker = 0

BF16:免调参的解决方案

BF16的设计从根本上解决了FP16的问题:

特性 FP16 BF16
指数位 5位 8位
尾数位 10位 7位
动态范围 窄(与FP32差异大) 宽(与FP32相同)
精度
需要损失缩放 通常需要 通常不需要

BF16保留了FP32的动态范围,这意味着:

  • 梯度下溢风险大幅降低
  • 权重更新下溢风险大幅降低
  • 在大多数情况下不需要损失缩放

代价是精度降低(只有7位尾数),但深度学习的数值研究表明,这种精度损失对模型质量影响有限。

graph TB
    subgraph FP16训练
        A1[需要损失缩放] --> B1[需要FP32主权重]
        B1 --> C1[动态缩放调参]
        C1 --> D1[复杂度高]
    end
    subgraph BF16训练
        A2[通常不需要损失缩放] --> B2[FP32主权重推荐但非必需]
        B2 --> C2[无需缩放调参]
        C2 --> D2[复杂度低]
    end
# PyTorch中使用BF16
model = model.to(torch.bfloat16)

# BF16训练通常不需要GradScaler
# 但仍需要FP32主权重
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(batch)
        loss = criterion(output)
    loss.backward()
    optimizer.step()

Flash Attention:分块计算的数值挑战

Flash Attention通过分块计算减少显存访问,但分块Softmax带来了新的数值稳定性挑战。

分块Softmax的困难

标准Softmax需要全局归一化,但Flash Attention将输入分成小块处理。如何保证分块计算的数值稳定性?

关键洞察是在线Softmax算法。假设我们要计算:

$$\text{softmax}(x) = \frac{e^{x}}{\sum_j e^{x_j}}$$

将$x$分成两块$x^{(1)}$和$x^{(2)}$:

$$\sum_j e^{x_j} = \sum_{j \in \text{block}_1} e^{x_j} + \sum_{j \in \text{block}_2} e^{x_j}$$

但这在数值上不稳定。正确的方法是跟踪每块的最大值:

def online_softmax(x_blocks):
    """在线Softmax算法"""
    global_max = -float('inf')
    global_sum = 0.0
    
    for block in x_blocks:
        block_max = np.max(block)
        
        if global_max == -float('inf'):
            global_max = block_max
            global_sum = np.sum(np.exp(block - block_max))
        else:
            # 关键:需要重新缩放之前的累积和
            if block_max > global_max:
                scale = np.exp(global_max - block_max)
                global_sum = global_sum * scale + np.sum(np.exp(block - block_max))
                global_max = block_max
            else:
                scale = np.exp(block_max - global_max)
                global_sum = global_sum + np.sum(np.exp(block - global_max))
    
    return global_max + np.log(global_sum)
flowchart LR
    subgraph 在线Softmax
        A[块1] --> C[当前最大值 m1<br/>累积和 exp x-m1]
        B[块2] --> D[新最大值 m2<br/>新累积和]
        C --> E{m2 > m1?}
        D --> E
        E -->|是| F[重新缩放旧累积和<br/>×exp m1-m2]
        E -->|否| G[缩放新累积和<br/>×exp m2-m1]
        F --> H[合并累积和]
        G --> H
    end

Flash Attention的稳定实现

Flash Attention的核心是分块计算注意力,同时保持数值稳定性:

def flash_attention(Q, K, V, block_size=256):
    """简化的Flash Attention实现"""
    seq_len = Q.shape[0]
    output = np.zeros_like(Q)
    
    for i in range(0, seq_len, block_size):
        q_block = Q[i:i+block_size]
        block_max = np.full(block_size, -float('inf'))
        block_sum = np.zeros(block_size)
        block_output = np.zeros_like(q_block)
        
        for j in range(0, seq_len, block_size):
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            
            # 计算当前块的注意力分数
            scores = np.dot(q_block, k_block.T)
            
            # 在线更新最大值和归一化因子
            new_max = np.maximum(block_max, np.max(scores, axis=1))
            scale = np.exp(block_max - new_max)
            scale_scores = np.exp(scores - new_max[:, None])
            
            block_sum = block_sum * scale + np.sum(scale_scores, axis=1)
            block_output = block_output * scale[:, None] + np.dot(scale_scores, v_block)
            block_max = new_max
        
        output[i:i+block_size] = block_output / block_sum[:, None]
    
    return output

关键点:

  1. 每块计算时减去当前最大值防止溢出
  2. 跨块合并时需要重新缩放之前的结果
  3. 最终归一化在所有块处理完后进行

梯度问题:深度网络的数学诅咒

梯度消失和爆炸是深度学习的经典问题,其根源在于反向传播的链式法则。

数学分析

考虑一个$L$层的深度网络,每层的激活函数为$\phi$。第$l$层的梯度为:

$$\frac{\partial L}{\partial W_l} = \frac{\partial L}{\partial a_L} \prod_{k=l+1}^{L} \frac{\partial a_k}{\partial a_{k-1}} \cdot \frac{\partial a_l}{\partial W_l}$$

关键在于连乘项$\prod_{k=l+1}^{L} \frac{\partial a_k}{\partial a_{k-1}}$。假设每层的导数近似为$\gamma$,则$L-l$层后的梯度大约为$\gamma^{L-l}$。

  • 当$|\gamma| < 1$时,$\gamma^{L-l} \to 0$(梯度消失)
  • 当$|\gamma| > 1$时,$\gamma^{L-l} \to \infty$(梯度爆炸)
graph LR
    subgraph 反向传播梯度流
        A[输出层<br/>梯度=1.0] --> B[第L-1层<br/>梯度×σ']
        B --> C[第L-2层<br/>梯度×σ']
        C --> D[...]
        D --> E[第1层<br/>梯度×σ'^L]
    end
    subgraph 梯度消失
        F["σ'=0.25<br/>10层后: 0.25^10≈10^-6"]
    end
    subgraph 梯度爆炸
        G["σ'>1<br/>可能发散"]
    end

Sigmoid的问题

Sigmoid函数$\sigma(x) = \frac{1}{1+e^{-x}}$的导数为:

$$\sigma'(x) = \sigma(x)(1-\sigma(x)) \leq 0.25$$

最大导数只有0.25。对于10层网络,梯度可能衰减到$0.25^{10} \approx 10^{-6}$,在FP16下直接变成0。

ReLU的改进

ReLU函数$f(x) = \max(0, x)$的导数为:

$$f'(x) = \begin{cases} 1 & x > 0 \\ 0 & x \leq 0 \end{cases}$$

当$x > 0$时,导数为1,不会衰减。这解决了梯度消失问题(在激活区域),但可能引入梯度爆炸。

梯度裁剪

梯度裁剪是解决梯度爆炸的常用技术:

def clip_grad_norm_(parameters, max_norm, norm_type=2):
    """梯度范数裁剪"""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    
    if norm_type == float('inf'):
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    
    return total_norm

两种常见策略:

  1. 按值裁剪:将梯度限制在$[-c, c]$
  2. 按范数裁剪:当梯度范数超过阈值时,缩放梯度使其范数等于阈值

Xavier与He初始化

合理的权重初始化可以从源头控制梯度问题。

Xavier初始化(适用于tanh/sigmoid):

$$W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} + n_{out}}}\right)$$

设计目标:使每层的输入方差等于输出方差,保持信号强度。

He初始化(适用于ReLU):

$$W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in}}}\right)$$

考虑到ReLU会"杀死"一半的神经元,需要额外的$\sqrt{2}$补偿。

Loss Spike:Transformer训练的幽灵

大模型训练中常见一种现象:损失函数突然跳升几个数量级,然后可能恢复,也可能发散。这就是Loss Spike。

Loss Spike的成因

研究表明,Loss Spike通常由以下因素触发:

  1. 数据异常:训练数据中的异常样本可能产生极大的梯度
  2. 学习率过高:过大的学习率使优化器跳过局部最优,进入损失地形的高曲率区域
  3. 注意力分布坍塌:Softmax输出的极端分布(如[0, 0, …, 1, …0])导致梯度不稳定
  4. BatchNorm统计异常:小批量中的极端样本扭曲running statistics
graph TB
    subgraph Loss Spike触发因素
        A[数据异常样本] --> E[极大梯度]
        B[学习率过高] --> F[跳过最优区]
        C[注意力分布坍塌] --> G[梯度不稳定]
        D[BatchNorm统计异常] --> H[归一化失效]
    end
    E --> I[梯度范数爆炸]
    F --> I
    G --> I
    H --> I
    I --> J[Loss Spike]

检测与缓解

class LossSpikeHandler:
    def __init__(self, threshold=5.0, patience=3):
        self.threshold = threshold  # 相对于滑动平均的倍数
        self.patience = patience
        self.loss_history = []
        self.spike_count = 0
    
    def check_spike(self, loss):
        self.loss_history.append(loss)
        
        if len(self.loss_history) < 10:
            return False
        
        # 计算滑动平均和标准差
        recent = self.loss_history[-10:]
        mean = np.mean(recent)
        std = np.std(recent)
        
        # 检测异常
        is_spike = loss > mean + self.threshold * std
        
        if is_spike:
            self.spike_count += 1
            if self.spike_count >= self.patience:
                return True  # 需要干预
        else:
            self.spike_count = 0
        
        return False
    
    def intervene(self, optimizer, original_lr):
        """干预策略"""
        # 降低学习率
        for param_group in optimizer.param_groups:
            param_group['lr'] = original_lr * 0.1

ZClip:自适应梯度裁剪

2025年提出的ZClip算法提供了一种自适应的解决方案:

class ZClip:
    """自适应梯度裁剪,专门针对Loss Spike"""
    def __init__(self, z_threshold=3.0, min_grad_norm=1e-6):
        self.z_threshold = z_threshold
        self.grad_norm_history = []
        self.min_grad_norm = min_grad_norm
    
    def clip(self, parameters):
        # 计算当前梯度范数
        total_norm = 0
        for p in parameters:
            if p.grad is not None:
                total_norm += p.grad.data.norm() ** 2
        total_norm = total_norm ** 0.5
        
        self.grad_norm_history.append(total_norm)
        
        if len(self.grad_norm_history) < 10:
            return total_norm
        
        # 计算Z分数
        recent = self.grad_norm_history[-10:]
        mean = np.mean(recent)
        std = np.std(recent)
        
        z_score = (total_norm - mean) / (std + self.min_grad_norm)
        
        # 如果Z分数超过阈值,裁剪
        if z_score > self.z_threshold:
            clip_coef = mean / (total_norm + self.min_grad_norm)
            for p in parameters:
                if p.grad is not None:
                    p.grad.data.mul_(clip_coef)
            return mean
        
        return total_norm

PyTorch中的数值稳定性最佳实践

PyTorch提供了多种内置工具来处理数值稳定性问题。

自动混合精度(AMP)

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(init_scale=2**15, growth_factor=2.0, 
                    backoff_factor=0.5, growth_interval=2000)

for epoch in range(num_epochs):
    for batch, target in dataloader:
        optimizer.zero_grad()
        
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(batch)
            loss = criterion(output, target)
        
        # 缩放损失并反向传播
        scaler.scale(loss).backward()
        
        # 反缩放梯度并裁剪
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # 更新权重
        scaler.step(optimizer)
        scaler.update()

数值稳定的层

PyTorch提供了数值稳定的组合层:

import torch.nn.functional as F

# 不要分开使用
logits = model(x)
probs = F.softmax(logits, dim=-1)
log_probs = torch.log(probs)  # 可能出现log(0)

# 使用组合版本
log_probs = F.log_softmax(logits, dim=-1)  # 数值稳定

# 交叉熵损失也内置了稳定的组合
loss = F.cross_entropy(logits, targets)  # 等价于 log_softmax + nll_loss

LayerNorm中的Epsilon

LayerNorm需要计算标准差,可能除以零。PyTorch内置了epsilon保护:

# LayerNorm实现
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + eps)  # eps防止除零
    return weight * x_norm + bias

默认eps=1e-5在FP32下足够,但在FP16训练中可能需要增大:

# FP16训练可能需要更大的epsilon
layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)  # 或更大

调试数值问题:从NaN到根源

当模型输出NaN时,系统性的调试流程至关重要。

检测NaN/Inf

def check_tensor_health(tensor, name="tensor"):
    """检查张量的数值健康状况"""
    has_nan = torch.isnan(tensor).any()
    has_inf = torch.isinf(tensor).any()
    
    if has_nan or has_inf:
        print(f"[ALERT] {name}: NaN={has_nan.item()}, Inf={has_inf.item()}")
        print(f"  Shape: {tensor.shape}")
        print(f"  Min: {tensor[~torch.isnan(tensor) & ~torch.isinf(tensor)].min() if (~torch.isnan(tensor) & ~torch.isinf(tensor)).any() else 'N/A'}")
        print(f"  Max: {tensor[~torch.isnan(tensor) & ~torch.isinf(tensor)].max() if (~torch.isnan(tensor) & ~torch.isinf(tensor)).any() else 'N/A'}")
        return False
    return True

# 使用钩子监控每层输出
def register_health_checks(model):
    for name, module in model.named_modules():
        module.register_forward_hook(
            lambda m, i, o, n=name: check_tensor_health(o, n) if o is not None else None
        )
flowchart TB
    subgraph NaN调试流程
        A[检测到NaN] --> B{检查输入数据}
        B -->|有NaN| C[修复数据预处理]
        B -->|无NaN| D{检查第一层输出}
        D -->|有NaN| E[检查初始化]
        D -->|无NaN| F{逐层检查}
        F -->|定位到某层| G{检查该层}
        G --> H[学习率过高?]
        G --> I[损失缩放?]
        G --> J[归一化eps?]
        G --> K[除法保护?]
    end

常见原因清单

症状 可能原因 解决方案
训练初期就出现NaN 学习率过高 降低学习率1-2个数量级
混合精度训练出现NaN 损失缩放不当 使用动态缩放或BF16
特定批次触发NaN 数据问题 检查数据预处理,添加NaN过滤
深层网络出现NaN LayerNorm epsilon过小 增大eps参数
Softmax后出现NaN 除法运算无保护 添加epsilon保护
第一轮迭代就出现NaN 权重初始化不当 使用Xavier或He初始化

FP8:数值稳定性的新前沿

FP8(8位浮点数)正在成为LLM训练的新标准,但带来了更严峻的数值挑战。

FP8格式

FP8有两种主流格式:

格式 符号位 指数位 尾数位 范围
E4M3 1 4 3 $[2^{-6}, 448]$
E5M2 1 5 2 $[2^{-14}, 57344]$

E4M3精度更高但范围更窄,适合存储激活和权重;E5M2范围更广但精度更低,适合存储梯度。

block-beta
    columns 4
    block:e4m3:4
        columns 8
        S1["S"]
        E1["E"]
        E2["E"]
        E3["E"]
        E4["E"]
        M1["M"]
        M2["M"]
        M3["M"]
    end
    block:e5m2:4
        columns 8
        S2["S"]
        E5["E"]
        E6["E"]
        E7["E"]
        E8["E"]
        E9["E"]
        M4["M"]
        M5["M"]
    end

FP8训练的核心挑战

  1. 异常特征(Outlier Features):Transformer训练中会出现激活值比其他值大100-1000倍的异常特征,FP8无法同时表示这种动态范围

  2. 量化误差累积:每次FP8转换都会损失精度,多层累积后可能导致发散

  3. 梯度表示:梯度通常有更大的动态范围和更多的小值

解决方案:动态缩放

class FP8DynamicScaler:
    """FP8动态缩放"""
    def __init__(self, margin=3):
        self.margin = margin  # 偏移量,避免溢出
        self.scale = 1.0
    
    def compute_scale(self, tensor, fp8_max=448.0):
        """计算最优缩放因子"""
        tensor_max = tensor.abs().max()
        if tensor_max == 0:
            return self.scale
        
        # 计算缩放因子使最大值接近FP8最大值
        target_max = fp8_max / (2 ** self.margin)
        new_scale = target_max / tensor_max
        
        # 平滑更新
        self.scale = 0.9 * self.scale + 0.1 * new_scale
        return self.scale
    
    def quantize(self, tensor):
        scale = self.compute_scale(tensor)
        return (tensor * scale).to(torch.int8).to(torch.float32) / scale

实践建议:构建数值稳定的训练流水线

综合以上分析,以下是构建数值稳定训练流水线的检查清单:

flowchart TB
    subgraph 数据阶段
        A1[检查极端值/NaN] --> A2[数据归一化]
        A2 --> A3[数据增强验证]
    end
    subgraph 模型阶段
        B1[权重初始化<br/>He/Xavier] --> B2[LayerNorm eps]
        B2 --> B3[使用稳定API<br/>log_softmax]
    end
    subgraph 训练阶段
        C1[混合精度<br/>优先BF16] --> C2[梯度裁剪]
        C2 --> C3[监控梯度/损失]
    end
    subgraph 调试阶段
        D1[健康检查钩子] --> D2[梯度统计日志]
        D2 --> D3[确定性模式复现]
    end
    A3 --> B1
    B3 --> C1
    C3 --> D1

数据阶段

  • 检查输入数据是否包含极端值或NaN
  • 确保数据归一化到合理范围(如[-1, 1]或[0, 1])
  • 添加数据增强时注意不产生数值异常

模型阶段

  • 使用He初始化(ReLU)或Xavier初始化(tanh/sigmoid)
  • LayerNorm/BatchNorm使用适当的epsilon
  • 避免手动实现softmax,使用框架的log_softmax

训练阶段

  • 使用混合精度训练(优先BF16)
  • 启用梯度裁剪(norm clipping)
  • 监控梯度范数和损失曲线

调试阶段

  • 注册张量健康检查钩子
  • 记录每层的梯度统计
  • 使用确定性模式复现问题

数值稳定性是深度学习系统工程的基石。理解浮点数的物理限制、掌握Log-Sum-Exp等数值技巧、合理配置混合精度训练,这些都是现代深度学习工程师的必备技能。当代价是72小时的训练成果化为NaN时,预防永远比修复更重要。


参考资料

  1. Micikevicius, P., et al. “Mixed Precision Training.” ICLR 2018.
  2. Zhao, R., et al. “Reducing Underflow in Mixed Precision Training by Gradient Scaling.” IJCAI 2020.
  3. NVIDIA. “Training With Mixed Precision.” NVIDIA Developer Documentation.
  4. Kahan, W. “IEEE Standard 754 for Binary Floating-Point Arithmetic.” 1997.
  5. Blanchard, P., et al. “Mixed Precision Block Fused Multiply-Add: Error Analysis.” 2019.
  6. Dao, T., et al. “FlashAttention: Fast and Memory-Efficient Exact Attention.” NeurIPS 2022.
  7. Peng, B., et al. “FP8-LM: Training FP8 Large Language Models.” arXiv 2023.
  8. “Numerical Stability in Flash Attention.” jarbus.net, 2023.
  9. “Numerically Stable Softmax and Cross Entropy.” Jay Mody, 2022.
  10. Gundersen, G. “The Log-Sum-Exp Trick.” 2020.
  11. “Understanding the FP64, FP32, FP16, BFLOAT16, TF32, FP8 Formats.” jeffreytse.net, 2024.
  12. “BF16 vs FP16: A Comparison of Performance and Efficiency.” Beam Blog, 2025.
  13. “An Empirical Study on Numerical Bugs in Deep Learning Programs.” ACM Digital Library.
  14. “Loss spikes in training: causes, detection, and mitigations.” Medium, 2026.
  15. “ZClip: Adaptive Spike Mitigation for LLM Pre-Training.” arXiv 2025.
  16. “Enhancing LLM Pretraining Stability via Adaptive Gradient Clipping.” arXiv 2026.