引言:一个被误读了四十年的算法

如果你问一个机器学习从业者:“反向传播算法是谁发明的?“大多数人的回答会是:“Geoffrey Hinton和他的合作者在1986年提出的。“这个答案听起来理所当然——毕竟,Hinton被称为"深度学习之父”,而那篇发表在《Nature》上的论文《Learning representations by back-propagating errors》至今仍被奉为经典。

但这个答案是错的。

反向传播算法并不是在1980年代发明的,也不是Hinton发明的,甚至不是在机器学习领域发明的。它是一套源于1960年代控制理论的优化方法,在1970年被芬兰程序员Seppo Linnainmaa正式算法化,然后在1974年被Paul Werbos首次应用于神经网络训练,最后才在1986年被Rumelhart、Hinton和Williams以一种更易于理解的方式重新包装并推广。

理解这段历史,不仅仅是为了还原真相。更重要的是,当我们把反向传播放在自动微分(Automatic Differentiation,AD)的大框架下审视时,会发现它只是冰山一角——一个被称为"反向模式自动微分"的特例。而自动微分的世界,远比我们想象的要广阔。

第一部分:从链式法则到自动微分

链式法则:三百年的数学根基

一切要从1676年说起。那一年,Gottfried Wilhelm Leibniz在一篇备忘录中首次提出了链式法则。如果你还记得大学微积分课程,链式法则告诉我们如何求复合函数的导数:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

这个看似简单的公式,是整个现代深度学习的数学基石。没有它,就没有反向传播,也就没有今天的大语言模型。

链式法则的核心思想是:如果一个变量通过一串中间变量间接依赖于另一个变量,我们可以通过将每一步的局部导数相乘来得到整体的导数。这个"分解-组合"的思想,正是自动微分的灵魂所在。

三种微分方法的对比

在深入自动微分之前,我们需要先厘清它与另外两种微分方法的区别。

符号微分(Symbolic Differentiation)

符号微分是我们在数学课上学到的那种方法:给定一个数学表达式,通过应用一系列求导规则(如幂法则、乘积法则、链式法则等)来推导出一个新的、表示导数的符号表达式。

# 符号微分示例
# 输入: f(x) = x^2 + 3x + 1
# 输出: f'(x) = 2x + 3

符号微分的优点是精确,可以得到解析解。但它的致命弱点是"表达式膨胀”——当函数变得复杂时,导数表达式可能呈指数级增长。想象一下对深度神经网络应用符号微分:你需要显式地写出一个包含数十亿次矩阵乘法、激活函数和损失函数的庞大导数表达式。这在计算上是不可行的。

数值微分(Numerical Differentiation)

数值微分采用最原始的方法:用差分近似导数。

$$f'(x) \approx \frac{f(x + \epsilon) - f(x)}{\epsilon}$$

这种方法简单直接,不需要知道函数的具体形式,只需要能够计算函数值。但它的精度受限于浮点数精度——$\epsilon$太大,近似误差大;$\epsilon$太小,舍入误差大。更糟糕的是,对于高维函数,数值微分的计算成本随维度线性增长。对于有数十亿参数的模型,这是完全不可接受的。

自动微分(Automatic Differentiation)

自动微分找到了一条巧妙的中间路线:它不像符号微分那样推导完整的符号表达式,也不像数值微分那样用有限差分近似。相反,它将复杂函数分解为一系列基本操作(加减乘除、指数、对数等),然后利用链式法则逐层计算导数。

关键洞察:自动微分不是近似——它计算的是精确的导数值(在浮点精度范围内)。它只是避免了存储和操作庞大的符号表达式。

第二部分:自动微分的两种模式

自动微分有两种基本模式:前向模式(Forward Mode)和反向模式(Reverse Mode)。这两种模式在数学上等价,但在计算效率上有着天壤之别。理解它们的差异,是理解为什么深度学习选择"反向传播"的关键。

前向模式:从输入到输出

前向模式自动微分的基本思想是:在计算函数值的同时,计算导数值。这听起来很自然——当我们从输入一步步计算到输出时,顺便把每一步的导数也算出来。

flowchart LR
    subgraph Forward Mode
        direction LR
        x["x (输入)"] --> v1["v1 = x + 2"]
        v1 --> v2["v2 = v1 * 3"]
        v2 --> y["y = v2^2 (输出)"]
        
        x -.->|"dx/dx = 1"| d1["dv1/dx = 1"]
        d1 -->|"d(v1*3)/dv1 = 3"| d2["dv2/dx = 3"]
        d2 -->|"d(v2^2)/dv2 = 2v2"| dy["dy/dx = 2v2 * 3"]
    end
    
    style x fill:#e1f5fe
    style y fill:#c8e6c9
    style dy fill:#fff9c4

前向模式的核心数据结构是对偶数(Dual Numbers)。对偶数将每个数表示为 $a + b\epsilon$,其中 $\epsilon$ 是一个满足 $\epsilon^2 = 0$ 的特殊符号(类似于虚数单位 $i$ 满足 $i^2 = -1$)。

对偶数的奇妙之处在于:当我们对对偶数进行运算时,结果中的 $\epsilon$ 系数正好是导数!

# 对偶数实现
class DualNumber:
    def __init__(self, real, dual):
        self.real = real  # 实部:函数值
        self.dual = dual  # 对偶部:导数值
    
    def __add__(self, other):
        return DualNumber(self.real + other.real, 
                          self.dual + other.dual)
    
    def __mul__(self, other):
        # (a + bε)(c + dε) = ac + (ad + bc)ε + bd*ε²
        # 由于 ε² = 0,结果为 ac + (ad + bc)ε
        return DualNumber(self.real * other.real,
                          self.real * other.dual + self.dual * other.real)
    
    def __pow__(self, n):
        # (a + bε)^n = a^n + n*a^(n-1)*b*ε
        return DualNumber(self.real ** n,
                          n * self.real ** (n-1) * self.dual)

# 使用对偶数计算 f(x) = x^2 + 3x + 1 在 x=2 处的值和导数
x = DualNumber(2, 1)  # x = 2 + 1*ε,导数初始为1
y = x ** 2 + x * DualNumber(3, 0) + DualNumber(1, 0)
print(f"f(2) = {y.real}, f'(2) = {y.dual}")
# 输出: f(2) = 11, f'(2) = 7

前向模式的计算复杂度是多少?假设函数有 $n$ 个输入和 $m$ 个输出,如果我们想计算函数对所有输入的梯度,需要进行 $n$ 次前向传播——每次将一个输入的对偶部设为1,其他设为0。

关键结论:前向模式的复杂度为 $O(n)$,其中 $n$ 是输入维度。

反向模式:从输出到输入

反向模式采取了截然不同的策略:先完整地执行前向传播,记录所有中间结果,然后从输出向输入反向传播梯度。

flowchart TB
    subgraph Reverse Mode - Forward Pass
        direction TB
        x["x (输入)"] --> v1["v1 = x + 2"]
        v1 --> v2["v2 = v1 * 3"]
        v2 --> y["y = v2^2 (输出)"]
    end
    
    subgraph Reverse Mode - Backward Pass
        direction BT
        dy["dy/dy = 1"] --> dv2["dv2 = dy/dy * 2v2 = 2v2"]
        dv2 --> dv1["dv1 = dv2 * 3 = 6v2"]
        dv1 --> dx["dx = dv1 * 1 = 6v2"]
    end
    
    style x fill:#e1f5fe
    style y fill:#c8e6c9
    style dy fill:#fff9c4

反向模式的执行过程分为两个阶段:

第一阶段:前向传播

  1. 从输入开始,逐层计算每个中间变量的值
  2. 构建计算图(Computational Graph),记录所有操作及其依赖关系
  3. 保存计算梯度所需的中间结果

第二阶段:反向传播

  1. 从输出开始,初始化梯度为1($\frac{\partial L}{\partial L} = 1$)
  2. 按拓扑逆序遍历计算图
  3. 对每个节点,应用链式法则:$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x}$

反向模式的计算复杂度是多少?无论输入有多少维度,只需要一次反向传播就可以得到输出对所有输入的梯度。

关键结论:反向模式的复杂度为 $O(m)$,其中 $m$ 是输出维度。

为什么深度学习选择反向模式?

现在,关键的问题来了:为什么深度学习框架(PyTorch、TensorFlow、JAX)都选择反向模式自动微分?

让我们算一笔账:

  • 典型的神经网络:输入维度(参数量)可能达到数十亿,输出维度(损失函数)只有一个标量
  • 前向模式:$O(n)$ 次传播,对于十亿参数的模型,需要十亿次前向传播——完全不可行
  • 反向模式:$O(m) = O(1)$ 次传播,无论参数多少,只需一次反向传播

这就是反向传播(反向模式自动微分)统治深度学习的根本原因:它在数学上完美契合了神经网络的"多输入、单输出"结构。

graph LR
    subgraph 复杂度比较
        direction LR
        A["输入维度 n"] --> B["前向模式: O(n)"]
        A --> C["反向模式: O(m)"]
        D["输出维度 m"] --> B
        D --> C
    end
    
    E["神经网络场景"] --> F["n >> m"]
    F --> G["反向模式胜出!"]
    
    style G fill:#c8e6c9

当然,反向模式也有代价:它需要存储前向传播的所有中间结果,这带来了巨大的内存开销。这也是为什么大模型训练需要大量的显存——不仅仅是存储参数,更要存储梯度计算所需的中间激活值。

第三部分:计算图的构建与遍历

什么是计算图?

计算图(Computational Graph)是自动微分的核心数据结构。它是一个有向无环图(DAG),其中:

  • 节点(Nodes):表示操作或变量
  • 边(Edges):表示数据依赖关系
graph BT
    subgraph 简单计算图示例: z = (x + y) * (x - y)
        x["x (leaf)"]
        y["y (leaf)"]
        x --> add["+"]
        y --> add
        x --> sub["-"]
        y --> sub
        add --> mul["*"]
        sub --> mul
        mul --> z["z (output)"]
    end
    
    style x fill:#e1f5fe
    style y fill:#e1f5fe
    style z fill:#c8e6c9

在PyTorch中,计算图是动态构建的——每次前向传播都会创建一个新的计算图。这种"定义即运行”(Define-by-Run)的设计,使得PyTorch能够自然地支持Python的控制流(循环、条件分支等),这是它相比早期TensorFlow静态图模式的主要优势。

PyTorch如何构建计算图?

让我们深入PyTorch的内部实现。当你执行一个操作时,PyTorch会创建一个Node对象(通常称为grad_fn),它知道如何计算该操作的梯度。

// PyTorch C++核心数据结构(简化版)
struct Node {
    std::vector<Edge> next_edges_;  // 指向父节点的边
    uint64_t sequence_nr_;          // 用于拓扑排序的序列号
    
    // 核心方法:给定输出梯度,计算输入梯度
    virtual variable_list apply(variable_list&& inputs) = 0;
};

struct Edge {
    std::shared_ptr<Node> function;  // 父节点
    uint32_t input_nr;               // 父节点的第几个输入
};

当你在Python中调用y = x * 3时,PyTorch在内部做了以下事情:

# 伪代码:PyTorch内部操作记录逻辑
def mul(x, other):
    # 1. 执行实际的数学运算
    result = tensor_mul(x, other)
    
    # 2. 如果任一输入需要梯度,创建反向节点
    if x.requires_grad or other.requires_grad:
        grad_fn = MulBackward0()
        
        # 3. 保存计算梯度所需的中间值
        grad_fn.save_for_backward([x, other])
        
        # 4. 建立边连接到输入节点
        grad_fn.set_next_edges([
            Edge(x.grad_fn, x.output_nr()),
            Edge(other.grad_fn, other.output_nr())
        ])
        
        # 5. 将grad_fn附加到输出
        result.grad_fn = grad_fn
        result.requires_grad = True
    
    return result

拓扑排序与反向遍历

当调用.backward()时,PyTorch需要确定反向传播的执行顺序。这通过拓扑排序实现:从输出节点开始,按照依赖关系反向排列所有节点。

def backward(root, grad_output):
    # 1. 拓扑排序
    nodes = topological_sort(root.grad_fn)
    
    # 2. 初始化梯度映射
    grad_map = {root.grad_fn: grad_output}
    
    # 3. 按拓扑逆序执行
    for node in nodes:
        # 获取该节点累积的梯度
        grad_inputs = grad_map[node]
        
        # 调用节点的反向函数
        grad_outputs = node.apply(grad_inputs)
        
        # 将梯度传播到父节点
        for i, edge in enumerate(node.next_edges_):
            if edge.function is not None:
                grad_map[edge.function] += grad_outputs[i]

SavedVariable:内存优化的艺术

反向传播需要前向传播的中间结果。PyTorch通过SavedVariable类来管理这些保存的值。

class SavedVariable {
private:
    at::Tensor data_;           // 实际的张量数据
    uint32_t version_counter_;  // 版本计数器,检测原地修改
    
public:
    // 前向时保存
    SavedVariable(const Variable& variable, bool is_output);
    
    // 反向时解包
    Variable unpack() const;
};

不同操作需要保存的内容不同:

操作 前向计算 反向梯度 需要保存的内容
平方 $y = x^2$ $\frac{\partial L}{\partial x} = 2x \cdot \frac{\partial L}{\partial y}$ 输入 $x$
指数 $y = e^x$ $\frac{\partial L}{\partial x} = y \cdot \frac{\partial L}{\partial y}$ 输出 $y$(比输入更小)
矩阵乘法 $Y = XW$ $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^T$ 输入 $X$ 和 $W$
ReLU $y = \max(0, x)$ $\frac{\partial L}{\partial x} = \mathbb{1}_{x>0} \cdot \frac{\partial L}{\partial y}$ 布尔掩码(非常小)

PyTorch会智能地选择保存输出而非输入(当输出更小时),或者只保存元数据(如转置操作的步长信息)而非实际数据。

第四部分:梯度累积与内存优化

梯度累积的两种场景

梯度累积是自动微分引擎的一项重要功能。它发生在两种不同的场景中:

场景一:多路径依赖

当一个变量在计算图中被多次使用时,来自不同路径的梯度会自动累加。

x = torch.tensor(2.0, requires_grad=True)
y = x * 2      # 第一条路径
z = x + 3      # 第二条路径
loss = y + z

loss.backward()
print(x.grad)  # tensor(3.) = d(loss)/dy * dy/dx + d(loss)/dz * dz/dx
               # = 1 * 2 + 1 * 1 = 3
graph LR
    x["x"] -->|"dy/dx = 2"| y["y = x * 2"]
    x -->|"dz/dx = 1"| z["z = x + 3"]
    y -->|"dloss/dy = 1"| loss["loss = y + z"]
    z -->|"dloss/dz = 1"| loss
    
    style x fill:#e1f5fe
    style loss fill:#c8e6c9

场景二:小批量梯度累积

在显存有限的情况下,我们可以通过累积多个小批量的梯度来模拟大批量训练。

# 模拟批量大小 128,实际每次处理 32
accumulation_steps = 4
optimizer.zero_grad()

for i, (x, y) in enumerate(dataloader):  # 实际批量大小 = 32
    output = model(x)
    loss = criterion(output, y) / accumulation_steps  # 关键:缩放损失
    
    loss.backward()  # 梯度累积在 parameter.grad 中
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

为什么要除以accumulation_steps?因为梯度是累加的,四次反向传播后梯度会是单次传播的四倍。除以步数确保等效于单次大批量计算的平均梯度。

梯度检查点:用时间换空间

对于超大模型,即使有GPU显存也可能不够用。梯度检查点(Gradient Checkpointing,也称激活检查点)是一种内存优化技术:在前向传播时不保存所有中间激活值,而是在反向传播时重新计算它们。

graph LR
    subgraph 无检查点
        A1["Layer 1"] --> A2["Layer 2 (保存)"]
        A2 --> A3["Layer 3 (保存)"]
        A3 --> A4["Layer 4 (保存)"]
        A4 --> A5["输出"]
    end
    
    subgraph 有检查点
        B1["Layer 1"] --> B2["Layer 2 (丢弃)"]
        B2 --> B3["Layer 3 (保存)"]
        B3 --> B4["Layer 4 (丢弃)"]
        B4 --> B5["输出"]
    end
    
    style A2 fill:#ffcdd2
    style A3 fill:#ffcdd2
    style A4 fill:#ffcdd2
    style B3 fill:#ffcdd2
from torch.utils.checkpoint import checkpoint

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1000, 1000) for _ in range(100)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            # 使用检查点包装每一层
            x = checkpoint(layer, x)
        return x

# 内存使用:从 ~400MB 降到 ~40MB
# 计算开销:增加约 20-30%

梯度检查点的核心思想:在前向传播时只保存"检查点”(每隔几层保存一次),反向传播时从最近的检查点重新计算被丢弃的激活值。这是一种"用计算换内存"的经典权衡。

第五部分:数值稳定性与常见陷阱

梯度消失与梯度爆炸

反向传播并非总是顺利的。当网络很深时,梯度可能会遇到两个极端问题:

梯度消失(Vanishing Gradients)

当激活函数的导数小于1时,经过多层反向传播后,梯度会呈指数级衰减。以Sigmoid函数为例:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$$$\sigma'(x) = \sigma(x)(1 - \sigma(x)) \leq 0.25$$

对于100层的网络,梯度最多会被乘以 $0.25^{100}$,这是一个几乎为零的数。结果就是:靠近输入的层几乎无法学习。

graph LR
    subgraph 梯度消失示意
        L1["Layer 1"] -->|"×0.25"| L2["Layer 2"]
        L2 -->|"×0.25"| L3["Layer 3"]
        L3 -->|"×0.25"| L4["..."]
        L4 -->|"×0.25^n ≈ 0"| LN["Layer n"]
    end
    
    style LN fill:#ffcdd2

梯度爆炸(Exploding Gradients)

相反,如果权重初始化过大,梯度会呈指数级增长,导致权重更新幅度过大,训练发散。

$$w_{t+1} = w_t - \eta \frac{\partial L}{\partial w}$$

当 $\frac{\partial L}{\partial w}$ 非常大时,$w_{t+1}$ 可能跳到一个完全不同的区域,甚至变成NaN

解决方案一:权重初始化

良好的权重初始化可以控制前向传播时的激活值方差和反向传播时的梯度方差。

Xavier初始化(适用于Sigmoid、Tanh):

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in} + n_{out}}\right)$$

Kaiming初始化(适用于ReLU):

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in}}\right)$$

解决方案二:激活函数选择

ReLU及其变体解决了Sigmoid的梯度消失问题:

# ReLU: max(0, x)
# 导数: 1 (x > 0), 0 (x ≤ 0)
# 问题: 负区间神经元"死亡"

# Leaky ReLU: max(αx, x), α通常为0.01
# 解决了"死亡ReLU"问题

# GELU: x * Φ(x), 其中Φ是标准正态分布的CDF
# 在Transformer中广泛使用,比ReLU更平滑

解决方案三:梯度裁剪

对于梯度爆炸,梯度裁剪是一种简单有效的方案:

# PyTorch中的梯度裁剪
loss.backward()

# 方式1:按范数裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 方式2:按值裁剪
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

optimizer.step()

按范数裁剪的数学原理:

$$g_{clipped} = g \cdot \min\left(1, \frac{\text{max\_norm}}{\|g\|}\right)$$

当梯度范数超过阈值时,将梯度缩放到阈值大小,保持梯度方向不变。

版本计数器:检测非法原地修改

PyTorch有一个常被忽视但至关重要的机制:版本计数器(Version Counter)。它用于检测可能导致梯度计算错误的非法原地修改。

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2  # y的反向计算需要x的值

# 危险!原地修改x
x.add_(1)  # x现在是 [2.0, 3.0, 4.0]

try:
    y.backward(torch.ones_like(y))
except RuntimeError as e:
    print(e)
    # RuntimeError: one of the variables needed for gradient 
    # computation has been modified by an inplace operation

每个张量都有一个版本号,每当发生原地修改时版本号递增。SavedVariable在保存时会记录版本号,在解包时检查版本是否匹配。如果不匹配,说明张量已被修改,梯度计算将产生错误结果——PyTorch选择抛出异常而非返回错误答案。

第六部分:现代框架的自动微分实现

PyTorch:定义即运行

PyTorch采用的是"定义即运行"(Define-by-Run)的动态图模式。计算图在前向传播时动态构建,每一步Python代码都会立即执行并更新图结构。

# PyTorch动态图示例
def forward(x, condition):
    result = x
    if condition:  # 运行时决定分支
        result = result * 2
    else:
        result = result + 1
    return result.sum()

# 不同的condition值产生不同的计算图
x = torch.randn(3, requires_grad=True)

# 第一次前向:condition=True,图有乘法节点
y1 = forward(x, True)
y1.backward()
print(x.grad)  # tensor([2., 2., 2.])

x.grad.zero_()

# 第二次前向:condition=False,图有加法节点
y2 = forward(x, False)
y2.backward()
print(x.grad)  # tensor([1., 1., 1.])

这种设计的优点是:

  1. 调试友好:可以使用Python调试器逐行检查
  2. 控制流自然:循环和条件分支就是普通的Python代码
  3. 灵活性强:每次前向传播可以有不同的图结构

缺点是:

  1. 性能开销:动态构建图有一定的开销
  2. 优化受限:难以进行全局图优化

JAX:函数式变换

JAX采取了完全不同的方法:它不是记录操作序列,而是对纯Python函数进行变换。

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# grad是一个函数变换器
grad_fn = jax.grad(loss_fn)

# 计算梯度
gradients = grad_fn(params, x, y)

JAX的grad是一个高阶函数:它接受一个函数,返回计算其梯度的新函数。这种函数式风格使得JAX能够进行强大的自动向量化、JIT编译和自动微分。

TensorFlow 2.x:两者兼得

TensorFlow 2.x引入了tf.GradientTape,结合了动态图的灵活性和静态图的优化能力:

import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x ** 2 + 2 * x + 1

grad = tape.gradient(y, x)
print(grad)  # tf.Tensor(8.0, shape=(), dtype=float32)

GradientTape像一个录音机,记录with块内的所有操作。这种方式既保持了灵活性,又为后续的编译优化留下了空间。

第七部分:从理论到实践——一个完整的反向传播追踪

让我们通过一个完整的例子来追踪PyTorch中的反向传播全过程:

import torch
import torch.nn as nn

# 定义一个简单的网络
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 1)
    
    def forward(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc2(h)

model = SimpleNet()
criterion = nn.MSELoss()

# 训练数据
x = torch.randn(1, 3)
target = torch.randn(1, 1)

# ===== 前向传播 =====
output = model(x)
loss = criterion(output, target)

# 查看计算图结构
print(f"Loss grad_fn: {loss.grad_fn}")
# <MseLossBackward0>

print(f"Output grad_fn: {loss.grad_fn.next_functions[0][0]}")
# <AddmmBackward0> (fc2层的矩阵乘法+偏置加法)

# ===== 反向传播 =====
loss.backward()

# 检查梯度
for name, param in model.named_parameters():
    print(f"{name}: grad shape = {param.grad.shape}")
# fc1.weight: grad shape = torch.Size([4, 3])
# fc1.bias: grad shape = torch.Size([4])
# fc2.weight: grad shape = torch.Size([1, 4])
# fc2.bias: grad shape = torch.Size([1])

反向传播的执行顺序:

graph BT
    Loss["Loss (MSELoss)"] -->|grad=1| G1["MseLossBackward0"]
    G1 -->|grad=2(output-target)| G2["AddmmBackward0 (fc2)"]
    G2 -->|grad=dL/dh| G3["ReluBackward0"]
    G3 -->|grad=dL/dh * mask| G4["AddmmBackward0 (fc1)"]
    G4 -->|grad=dL/dx| X["Input (no grad)"]
    
    G2 -.->|dL/dW2| W2["fc2.weight.grad"]
    G2 -.->|dL/db2| B2["fc2.bias.grad"]
    G4 -.->|dL/dW1| W1["fc1.weight.grad"]
    G4 -.->|dL/db1| B1["fc1.bias.grad"]
    
    style Loss fill:#c8e6c9
    style W2 fill:#fff9c4
    style B2 fill:#fff9c4
    style W1 fill:#fff9c4
    style B1 fill:#fff9c4

结语:站在巨人的肩膀上

当我们调用loss.backward()时,我们在使用一套跨越六十年的智慧结晶:

  • 1961年,Arthur Bryson在控制理论中发展了伴随方法
  • 1962年,Stuart Dreyfus明确应用反向链式法则
  • 1970年,Seppo Linnainmaa形式化了反向模式自动微分
  • 1974年,Paul Werbos首次将反向模式应用于神经网络
  • 1986年,Rumelhart、Hinton和Williams让这套方法广为人知

今天,当我们训练一个拥有数千亿参数的大语言模型时,底层的数学原理与六十年前并无二致。但正是这些看似简单的链式法则迭代,在数十亿次重复后,让机器学会了理解语言、生成图像、甚至进行推理。

自动微分不仅是一个技术工具,更是一种思维方式的体现:将复杂问题分解为简单步骤,然后系统地组合结果。这种"分而治之"的思想,从Leibniz的时代延续至今,仍然是人类应对复杂性的最有力武器。

理解自动微分,就是理解现代机器学习的基石。下次当你看到模型损失曲线下降时,不妨想一想:在这背后,是数不清的链式法则在默默工作,将误差信号从输出一层层传回输入,推动着参数向更好的方向移动。这就是学习的本质——在数学上,它叫做梯度下降;在哲学上,它叫做反馈与修正。

参考资料

  1. Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088), 533-536.

  2. Linnainmaa, S. (1970). The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors. Master’s Thesis, University of Helsinki.

  3. Werbos, P. J. (1974). Beyond regression: New tools for prediction and analysis in the behavioral sciences. PhD Thesis, Harvard University.

  4. Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2018). Automatic differentiation in machine learning: a survey. Journal of Machine Learning Research, 18(153), 1-43.

  5. Paszke, A., et al. (2019). PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32.

  6. Bradbury, J., et al. (2018). JAX: composable transformations of Python+NumPy programs.

  7. Griewank, A., & Walther, A. (2008). Evaluating derivatives: principles and techniques of algorithmic differentiation. SIAM.

  8. Schmidhuber, J. (2015). Deep learning in neural networks: An overview. Neural Networks, 61, 85-117.

  9. He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. ICCV.

  10. Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. AISTATS.

  11. Chen, T., et al. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.

  12. Griewank, A. (1992). Achieving logarithmic growth of temporal and spatial complexity in reverse automatic differentiation. Optimization Methods and Software, 1(1), 35-54.

  13. Dreyfus, S. E. (1962). The numerical solution of variational problems. Journal of Mathematical Analysis and Applications, 5(1), 30-45.

  14. Bryson, A. E., & Denham, W. F. (1961). A steepest-ascent method for solving optimum programming problems. Journal of Applied Mechanics, 29(2), 247-257.

  15. Leibniz, G. W. (1676). Memoir on the chain rule. (Historical document).

  16. Clifford, W. K. (1873). Preliminary sketch of biquaternions. Proceedings of the London Mathematical Society, 1(1), 381-395.

  17. Pascanu, R., Mikolov, T., & Bengio, Y. (2013). On the difficulty of training recurrent neural networks. ICML.

  18. Bengio, Y., Simard, P., & Frasconi, P. (1994). Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks, 5(2), 157-166.

  19. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep learning. MIT Press.

  20. Abadi, M., et al. (2016). TensorFlow: A system for large-scale machine learning. OSDI.

  21. Roberts, D. A. (2022). SGD implicitly regularizes generalization error. NeurIPS.

  22. Pearlmutter, B. A. (1994). Fast exact multiplication by the Hessian. Neural Computation, 6(1), 147-160.

  23. Innes, M., et al. (2018). A differentiable programming system to bridge machine learning and scientific computing. arXiv preprint arXiv:1907.07592.

  24. Raaijmakers, W. (2026). Optimize PyTorch training with the autograd engine. Red Hat Developer.

  25. PyTorch Documentation. (2024). torch.autograd. https://docs.pytorch.org/docs/stable/autograd.html

  26. Chen, R. T. Q., et al. (2018). Neural ordinary differential equations. NeurIPS.

  27. Maclaurin, D., Duvenaud, D., & Adams, R. (2015). Gradient-based hyperparameter optimization through reversible learning. ICML.

  28. Gomez, A. N., et al. (2017). The reversible residual network: Backpropagation without storing activations. NeurIPS.

  29. Micikevicius, P., et al. (2018). Mixed precision training. ICLR.

  30. Nair, V., & Hinton, G. E. (2010). Rectified linear units improve restricted boltzmann machines. ICML.

  31. Hendrycks, D., & Gimpel, K. (2016). Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415.

  32. Clevert, D. A., Unterthiner, T., & Hochreiter, S. (2016). Fast and accurate deep network learning by exponential linear units (ELUs). ICLR.

  33. Klambauer, G., et al. (2017). Self-normalizing neural networks. NeurIPS.

  34. Ioffe, S., & Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. ICML.

  35. Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.

  36. Zhang, J., He, T., Sra, S., & Jadbabaie, A. (2019). Why gradient clipping accelerates training: A theoretical justification for adaptivity. ICLR.

  37. Merity, S., Keskar, N. S., & Socher, R. (2018). An analysis of neural language modeling at multiple scales. arXiv preprint arXiv:1803.08240.

  38. Smith, L. N. (2017). Cyclical learning rates for training neural networks. WACV.

  39. Loshchilov, I., & Hutter, F. (2017). Decoupled weight decay regularization. ICLR.

  40. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. ICLR.

  41. Bottou, L. (2010). Large-scale machine learning with stochastic gradient descent. COMPSTAT.

  42. Robbins, H., & Monro, S. (1951). A stochastic approximation method. The Annals of Mathematical Statistics, 22(3), 400-407.