当一个大语言模型输出"巴黎是法国的首都"时,这句话并非在最后一层突然涌现。在Transformer的数十层神经网络中,每一层都在逐步构建、修正、精炼这个预测。但如何窥视这个"黑箱"内部的思考过程?

2020年8月,一位网名为"nostalgebraist"的研究者在LessWrong上发布了一篇博客文章,提出了一个看似简单却极具洞察力的技术:将Transformer中间层的隐藏状态直接投影到词汇空间。这个被称为"Logit Lens"的方法,让研究者第一次能够看到模型在每一层的"猜测"。

三年后,Nora Belrose等人发表的Tuned Lens论文,进一步完善了这一技术,使其能够应用于更多类型的模型。这两项工作共同构成了理解Transformer内部计算过程的核心工具。

从输出回溯:Transformer的预测是如何形成的

要理解Logit Lens,首先需要理解Transformer的基本结构。以GPT-2为例,一个具有48层的模型处理文本时,每一层都会接收一个向量,对其进行变换,然后输出给下一层。这个过程中,信息不断被加工、组合、精炼。

传统上,研究者只能看到模型的最终输出——经过48层处理后,模型对下一个词的预测分布。但这个最终结果是如何一步步形成的?中间层究竟在做些什么?

Logit Lens的核心思想出奇简单:既然每一层都会产生一个隐藏状态向量,而这个向量最终会被转换成对词汇表的概率分布,那为什么不提前做这个转换?

具体来说,Transformer的最后一层通常包含一个"解嵌入矩阵"(unembedding matrix)$W_U$,它将隐藏状态映射到词汇空间。对于一个词汇表大小为$V$、隐藏维度为$d$的模型,$W_U$的形状是$d \times V$。给定最后一层的隐藏状态$h_L$,logits的计算方式是:

$$\text{logits} = h_L \cdot W_U$$

Logit Lens所做的,就是对每一层的隐藏状态$h_\ell$都应用这个投影:

$$\text{logits}_\ell = h_\ell \cdot W_U$$

这样,我们就能够看到模型在每一层"认为"下一个词应该是什么。

graph TD
    A[输入Token] --> B[Embedding层]
    B --> C[Layer 1]
    C --> D[Layer 2]
    D --> E[...]
    E --> F[Layer L]
    F --> G[最终输出]
    
    C --> H[Logit Lens投影]
    D --> I[Logit Lens投影]
    F --> J[Logit Lens投影]
    
    H --> K[中间层预测分布]
    I --> L[中间层预测分布]
    J --> M[最终预测分布]
    
    style H fill:#e1f5fe
    style I fill:#e1f5fe
    style J fill:#c8e6c9

一个具体的观察:预测是如何逐层演化的

让我们看一个具体的例子。当GPT-2处理"语言模型是一种"这段文本时,模型需要预测下一个词。

在早期层(比如第1-5层),logit lens显示的top-1预测往往是通用的、信息量低的词,如"一种"、“一个”、“这个”——这些词几乎可以在任何上下文中出现。

到了中间层(第10-20层),预测开始变得更有针对性。“工具”、“技术”、“方法"这类词开始占据主导地位。模型似乎已经理解了这是一个定义性的句子结构。

在后期层(第30层以后),预测进一步精确化,可能出现"AI技术”、“机器学习方法"等更具体的表述,最终收敛到模型认为最合适的答案。

graph LR
    subgraph "Layer 1-5 早期层"
        A1["一种<br/>P=0.15"]
        A2["一个<br/>P=0.12"]
        A3["这个<br/>P=0.10"]
    end
    
    subgraph "Layer 10-20 中间层"
        B1["工具<br/>P=0.22"]
        B2["技术<br/>P=0.18"]
        B3["方法<br/>P=0.15"]
    end
    
    subgraph "Layer 30+ 后期层"
        C1["AI技术<br/>P=0.35"]
        C2["算法<br/>P=0.25"]
        C3["程序<br/>P=0.15"]
    end
    
    A1 --> B1
    A2 --> B2
    B1 --> C1
    B2 --> C2
    
    style A1 fill:#ffcdd2
    style B1 fill:#fff9c4
    style C1 fill:#c8e6c9

这个观察揭示了一个重要事实:Transformer的计算是一个渐进的精炼过程,而非突然的顿悟。每一层都在为上一层的输出添加信息,逐步将模糊的猜测转化为精确的预测。

意外的发现:输入被"遗忘"了

nostalgebraist在原始博客中报告了一个令人惊讶的发现:输入层的表示与后续所有层的表示都截然不同。

如果计算输入层(embedding层)的logit lens输出,它实际上是在尝试预测下一个词——但仅基于当前单个词的信息。由于缺乏上下文,这个预测通常毫无意义。然而,一旦进入第一个Transformer层,表示就会发生剧烈变化,与输入层产生巨大的KL散度。

graph TD
    subgraph "表示空间分析"
        E[Embedding层<br/>KL散度: 极高] --> |"剧烈变化"| L1[Layer 1]
        L1 --> |"平滑过渡"| L2[Layer 2]
        L2 --> |"平滑过渡"| L3["..."]
        L3 --> |"平滑过渡"| LN[Layer N<br/>最终输出]
    end
    
    subgraph "信息流向"
        I[输入信息] --> |"被重构"| R[预测表示]
    end
    
    style E fill:#ffcdd2
    style L1 fill:#fff9c4
    style LN fill:#c8e6c9

这意味着什么?模型并非简单地"保持"输入信息然后逐步添加处理。相反,输入在进入第一个Transformer层时就被彻底重构了。模型几乎立即丢弃了"这看起来像输入token X"的信息,转而开始构建"这看起来应该输出token Y"的表示。

这个发现挑战了一种直觉:人们可能认为早期层负责"理解输入”,后期层负责"生成输出"。实际上,整个网络从头到尾都在做输出预测,只是预测的质量逐步提高。

技术局限:Logit Lens并非万能钥匙

尽管Logit Lens提供了有价值的洞见,但它存在明显的局限性。

模型依赖性是最突出的问题。原始Logit Lens在GPT-2上表现出色,但在其他模型架构上效果参差不齐。例如,在BLOOM和OPT-125M等模型上,logit lens的中间层预测往往不连贯,甚至在超过一半的层中,top-1预测是输入token本身,而非任何合理的续写。

**表示漂移(Representational Drift)**是另一个关键问题。不同层的隐藏状态具有不同的统计特性。具体而言,隐藏状态的协方差矩阵会随着层数增加而逐渐变化。当我们将早期层的隐藏状态直接投影到词汇空间时,这个投影可能产生误导性的结果——不是因为早期层"不知道"答案,而是因为早期层的表示使用了不同的"语言"(不同的基向量)。

偏差问题也不容忽视。Logit Lens对某些词汇存在系统性偏好。在GPT-Neo-2.7B上,研究者发现logit lens的边际分布与模型最终输出的边际分布之间存在4-5比特的KL散度。这意味着logit lens可能会系统性地高估或低估某些词的概率,使其成为一个有偏的估计器。

graph LR
    subgraph "Logit Lens的问题"
        A[模型依赖性] --> A1["某些模型效果差<br/>BLOOM/OPT/GPT-Neo"]
        B[表示漂移] --> B1["不同层使用不同的基向量<br/>协方差矩阵漂移"]
        C[偏差问题] --> C1["对某些词系统性偏好<br/>KL散度可达4-5比特"]
    end

Tuned Lens:用学习来校正漂移

2023年3月,Nora Belrose、Zach Furman等人在论文"Eliciting Latent Predictions from Transformers with the Tuned Lens"中提出了一种改进方案。

Tuned Lens的核心思想是:与其直接使用最后一层的解嵌入矩阵,不如为每一层学习一个专门的"翻译器"。

具体而言,对于层$\ell$,Tuned Lens引入一个可学习的仿射变换$(A_\ell, b_\ell)$,将隐藏状态$h_\ell$映射到一个"校正后"的表示:

$$h'_\ell = A_\ell \cdot h_\ell + b_\ell$$

然后再应用解嵌入矩阵:

$$\text{logits}_\ell = h'_\ell \cdot W_U$$

这个仿射变换通过最小化KL散度来训练:

$$\mathcal{L} = \mathbb{E}\left[D_{KL}(p_{\text{final}} \| p_\ell)\right]$$

其中$p_{\text{final}}$是模型最终层的输出分布,$p_\ell$是Tuned Lens在层$\ell$的输出分布。

graph TD
    subgraph "Logit Lens"
        H1[隐藏状态 h] --> LN1[LayerNorm]
        LN1 --> U1[解嵌入矩阵 W_U]
        U1 --> L1[Logits]
    end
    
    subgraph "Tuned Lens"
        H2[隐藏状态 h] --> T[可学习仿射变换<br/>A·h + b]
        T --> LN2[LayerNorm]
        LN2 --> U2[解嵌入矩阵 W_U]
        U2 --> L2[Logits]
    end
    
    style T fill:#e1bee7

这个训练目标的精妙之处在于:它使用模型自身的最终输出作为"软标签",而非外部的真实标签。这确保了Tuned Lens学习的是"模型知道什么",而非"我们希望模型知道什么"。

Tuned Lens的效果是显著的。在GPT-Neo-2.7B上,Logit Lens直到第21层才能产生有意义的预测,而Tuned Lens从第1层就能看到逐渐收敛的预测轨迹。困惑度(perplexity)也大幅降低——Tuned Lens的中间层预测困惑度远低于Logit Lens,更接近模型最终输出的困惑度。

graph LR
    subgraph "困惑度对比 GPT-Neo-2.7B"
        direction TB
        A["Layer 1<br/>Logit Lens: 100+<br/>Tuned Lens: 25"]
        B["Layer 10<br/>Logit Lens: 50<br/>Tuned Lens: 15"]
        C["Layer 21<br/>Logit Lens: 20<br/>Tuned Lens: 12"]
        D["Layer 24<br/>最终: 10"]
    end
    
    A --> B --> C --> D
    
    style A fill:#ffcdd2
    style C fill:#c8e6c9

从观察到干预:预测轨迹的实际应用

Logit Lens和Tuned Lens不仅是观察工具,它们还为实际应用打开了大门。

幻觉检测是其中一个重要方向。当大语言模型生成虚假信息时,这些信息往往在预测轨迹中留下痕迹。一项研究发现,当模型产生幻觉时,预测轨迹可能表现出异常的模式:某些层突然偏离了合理的预测路径,或者在后期层出现剧烈的波动。通过监测这些轨迹异常,可以构建幻觉检测器。

知识定位与编辑是另一个关键应用。Kevin Meng等人在NeurIPS 2022上发表的论文"Locating and Editing Factual Associations in GPT"中,结合因果追踪技术,发现事实知识存储在特定的MLP层中。Logit Lens帮助研究者定位了哪些层负责特定事实的检索,进而实现了精确的知识编辑——例如,将"埃菲尔铁塔位于罗马"这样的反事实知识写入模型,同时保持其他知识不变。

graph TD
    subgraph "知识编辑流程"
        A[输入: 埃菲尔铁塔位于] --> B[前向传播]
        B --> C{因果追踪<br/>定位关键层}
        C --> |"Layer 5 MLP"| D[识别存储位置]
        D --> E[应用ROME编辑]
        E --> F[输出: 罗马]
    end
    
    subgraph "Logit Lens作用"
        LL1[逐层预测] --> LL2[识别知识浮现层]
        LL2 --> LL3[指导编辑位置选择]
    end
    
    C -.-> LL1
    
    style C fill:#e1bee7
    style E fill:#c8e6c9

恶意输入检测也展现出潜力。Belrose等人的研究表明,Tuned Lens的预测轨迹可以用于检测提示注入攻击。恶意构造的输入会导致预测轨迹呈现独特的不连续性,正常输入则表现出平滑的收敛模式。在实验中,这种方法在某些场景下实现了接近完美的检测准确率。

理解Few-shot学习的机制也得到了推进。研究者使用Logit Lens分析模型如何处理few-shot示例,发现了一个有趣现象:当提供错误的示例时,早期层的预测往往比最终层的预测更准确。这意味着模型的"直觉"可能是正确的,但在后期层被错误示例"误导"了。

逐层预测的数学本质

要真正理解Logit Lens,需要从迭代推理的视角看待Transformer。

Transformer的残差连接结构可以表示为:

$$h_{\ell+1} = h_\ell + f_\ell(h_\ell)$$

其中$f_\ell$是第$\ell$层的变换(包括自注意力和MLP)。这个结构暗示了一种"增量更新"的计算模式:每一层在上一层的表示基础上添加一个小的修正。

从迭代推理的角度,可以将$h_\ell$视为模型在第$\ell$步对下一个词的"潜在预测",而$f_\ell(h_\ell)$则是基于当前信息对这个预测的"更新"。

Logit Lens所做的,就是将这个潜在预测显式地解码到词汇空间。如果这个视角正确,我们期望看到:随着层数增加,预测的困惑度单调递减,预测分布逐渐收敛到最终输出。

graph TD
    subgraph "残差流: 信息传递总线"
        H0["h₀ = Embedding"] --> |"+ f₀"| H1["h₁"]
        H1 --> |"+ f₁"| H2["h₂"]
        H2 --> |"..."| HN["h_N"]
        
        subgraph "各层贡献"
            F0["f₀: 语法结构"]
            F1["f₁: 语义关系"]
            FN["f_N: 推理整合"]
        end
        
        H0 -.-> F0
        H1 -.-> F1
    end
    
    style H0 fill:#ffcdd2
    style HN fill:#c8e6c9

实验结果很大程度上支持了这个预测。在大多数情况下,Tuned Lens显示的预测轨迹确实表现出平滑的单调收敛特性。但也有一些例外——特别是在处理复杂推理任务时,预测轨迹可能呈现非单调的行为,暗示着模型内部的"辩论"或"冲突解决"过程。

残差流与信息传递

Logit Lens的发现与"残差流"(Residual Stream)的概念密切相关。Anthropic的研究者在"A Mathematical Framework for Transformer Circuits"中提出,Transformer的残差连接不仅仅是网络优化的技巧,而是信息传递的核心机制。

在这种视角下,残差流就像一条"通信总线",各个层通过注意力头和MLP向这条总线"写入"或从中"读取"信息。每个层的输出都可以分解为:保留原有信息(恒等映射)加上新添加的信息(残差分支)。

Logit Lens的观察——早期层就与输入表示截然不同——可以用这个框架来解释。第一层并不试图"保持"输入表示,而是立即开始向残差流"写入"预测相关的信息。随着层数增加,这些信息不断累积和精炼,最终形成完整的预测。

这种设计有一个重要优势:它允许不同层专注于不同类型的信息处理。早期层可能负责提取基本语法结构,中间层处理语义关系,后期层进行推理和知识检索——所有这些都通过共享的残差流协调。

实现Logit Lens:关键细节

要在实践中实现Logit Lens,有几个技术细节需要注意。

首先,LayerNorm的处理至关重要。在Pre-LN(层归一化在前)架构中——这是现代大模型的标配——隐藏状态在进入每一层之前会先经过归一化。因此,Logit Lens也应该对中间层的隐藏状态应用归一化,然后再投影到词汇空间。

其次,对于权重共享的模型(如GPT-2,输入embedding矩阵与输出解嵌入矩阵共享权重),可以直接使用这个共享矩阵进行投影。但对于不共享权重的模型,需要使用专门的解嵌入矩阵。

第三,Tuned Lens的训练需要选择合适的数据。研究者建议使用模型预训练时的验证集数据,而非任意文本。这确保了translator学习的是模型"熟悉"的表示分布。

以下是一个简化的Logit Lens实现逻辑:

def logit_lens(hidden_states, unembed_matrix, ln_final):
    """
    hidden_states: list of tensors, each [batch, seq_len, hidden_dim]
    unembed_matrix: [hidden_dim, vocab_size]
    ln_final: LayerNorm module
    """
    layer_predictions = []
    
    for h in hidden_states:
        # 应用与最终层相同的归一化
        h_normed = ln_final(h)
        # 投影到词汇空间
        logits = h_normed @ unembed_matrix
        layer_predictions.append(logits)
    
    return layer_predictions

Tuned Lens在此基础上增加了可学习的仿射变换:

class TunedLens:
    def __init__(self, num_layers, hidden_dim, vocab_size):
        # 为每一层创建translator
        self.translators = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) 
            for _ in range(num_layers)
        ])
        self.unembed = nn.Linear(hidden_dim, vocab_size, bias=False)
    
    def forward(self, hidden_states, layer_idx):
        h = hidden_states[layer_idx]
        h_translated = self.translators[layer_idx](h)
        return self.unembed(h_translated)

超越文本:Logit Lens的扩展

Logit Lens的思想不仅限于文本模型,研究者已将其扩展到其他领域。

视觉语言模型方面,2024年CVPR论文"Devils in Middle Layers of Large Vision-Language Models"将logit lens应用于图像token的隐藏状态,揭示了多模态模型如何逐步理解视觉信息。研究发现,图像理解呈现出与文本类似的多阶段模式:早期层识别基本视觉元素,中间层整合语义信息,后期层进行任务相关的推理。

扩散模型的文本编码器也成为了研究对象。Diffusion Lens技术揭示了文本到图像模型中的文本编码器如何将文本提示转换为视觉概念表示。一个有趣的发现是:复杂概念需要更多层才能形成忠实表示,而常见概念在早期层就能浮现。

graph TD
    subgraph "视觉语言模型的Logit Lens"
        I[图像输入] --> V[视觉编码器]
        T[文本输入] --> TE[文本编码器]
        V --> F[特征融合层]
        TE --> F
        F --> L[语言模型层]
        L --> O[输出]
        
        V --> VL1[Logit Lens<br/>视觉理解轨迹]
        L --> VL2[Logit Lens<br/>文本生成轨迹]
    end
    
    style VL1 fill:#e1f5fe
    style VL2 fill:#c8e6c9

扩散模型本身的去噪过程也受到了类似的分析。研究者开发了针对扩散模型生成过程的"lens"技术,观察噪声如何逐步转化为清晰图像。

理论意义:理解与可解释性

Logit Lens的价值超越了其技术实用性,它为理解神经网络的"思维过程"提供了一个具体的窗口。

长期以来,神经网络被视为"黑箱":我们知道它们有效,但难以解释为什么有效。Logit Lens提供了一个反例——它让我们能够逐层追踪模型的推理过程,看到预测是如何从模糊到清晰,从通用到具体。

这种能力对于AI安全和可解释性研究尤为重要。如果能够理解模型在每一层的"意图",就更容易检测模型是否在"欺骗"用户,是否在使用不安全的推理路径,是否在生成有害内容。这些都是构建可信AI系统的关键信息。

同时,Logit Lens也揭示了深度学习的一些本质特性。它显示了大模型并非简单地记忆训练数据,而是真正在进行某种形式的推理——即使这种推理可能与人类的思维方式截然不同。预测轨迹的平滑收敛暗示着一种"理性"的更新过程,而轨迹中的突然变化可能标志着关键的推理步骤。

graph LR
    subgraph "可解释性技术栈"
        A[Logit Lens<br/>预测轨迹] --> B[注意力可视化<br/>信息流动]
        B --> C[稀疏自动编码器<br/>特征分解]
        C --> D[因果追踪<br/>关键节点]
        D --> E[完整理解<br/>模型机制]
    end
    
    style A fill:#e1bee7
    style E fill:#c8e6c9

局限与未来方向

尽管Logit Lens和Tuned Lens取得了显著成功,它们仍面临挑战。

计算开销是实际应用中的障碍。要获得完整的预测轨迹,需要对每个token的前向传播进行修改,存储每一层的隐藏状态。对于超大规模模型,这可能带来显著的内存压力。

因果解释的局限也需要注意。Logit Lens显示的是"如果在这一层停止会预测什么",但这不意味着模型在正常推理中真的会这样做。残差连接意味着后续层的输出会覆盖(或修正)早期层的表示,因此早期层的预测可能从未真正影响最终输出。

多token生成场景下的应用仍需探索。Logit Lens主要分析单个位置下一个token的预测,但实际生成涉及多token序列。如何扩展到序列级别的分析是一个开放问题。

与其他可解释性技术的整合是未来方向之一。例如,将Logit Lens与注意力可视化、稀疏自动编码器、因果追踪等技术结合,可能提供更全面的模型理解。

技术选型的权衡

在实践中,Logit Lens与Tuned Lens各有适用场景。

Logit Lens的优势在于零成本部署——它不需要任何额外训练,可以直接应用于任何兼容架构的模型。这使得它非常适合快速原型开发和初步分析。

Tuned Lens的优势在于更高的准确性和更广的适用性——它能够处理Logit Lens失效的模型架构,并提供更低困惑度的中间层预测。但代价是需要额外的训练数据和计算资源。

对于研究者而言,一个实用的策略是:先用Logit Lens进行快速探索,如果发现预测质量不足或不连贯,再考虑训练Tuned Lens。

从"看见"到"理解"

Logit Lens和Tuned Lens代表了大模型可解释性研究的一个重要里程碑。它们不仅提供了观察模型内部状态的工具,更重要的是,它们揭示了Transformer计算的某种本质:迭代精炼、渐进预测、残差流信息传递。

这些发现正在改变我们设计和理解大模型的方式。当模型表现不佳时,Logit Lens可以帮助诊断问题出在哪个阶段。当模型产生有害输出时,预测轨迹可能揭示不安全的推理路径。当需要编辑模型知识时,这些工具帮助定位目标信息的位置。

当然,我们距离真正"理解"神经网络仍有很长的路。Logit Lens展示的是模型"预测什么",而非"如何预测"。要回答后者,需要更深入地分析注意力头、MLP层以及它们之间的交互。但Logit Lens已经证明:神经网络内部并非完全不可窥视,至少在某种程度上,我们能够看到它们"在想什么"。


参考文献

  1. nostalgebraist. “Interpreting GPT: the logit lens.” LessWrong, August 2020.

  2. Belrose, N., Furman, Z., Smith, L., Halawi, D., Ostrovsky, I., McKinney, L., Biderman, S., & Steinhardt, J. “Eliciting Latent Predictions from Transformers with the Tuned Lens.” arXiv:2303.08112, March 2023.

  3. Meng, K., Bau, D., Andonian, A., & Belinkov, Y. “Locating and Editing Factual Associations in GPT.” NeurIPS 2022.

  4. Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., DasSarma, N., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., McCandlish, S., & Olah, C. “A Mathematical Framework for Transformer Circuits.” Anthropic, 2021.

  5. Geva, M., Schuster, R., Berant, J., & Levy, O. “Transformer Feed-Forward Layers Are Key-Value Memories.” EMNLP 2021.

  6. Halawi, D., Denain, J. S., & Steinhardt, J. “Overthinking the Truth: Understanding how Language Models Process False Demonstrations.” ICLR 2024.

  7. Toker, D., Orgad, H., Ventura, M., Belinkov, Y., & De Giorgi, L. “Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines.” ACL 2024.

  8. Jiang, X., Chen, Y., Wang, J., Zhang, W., & Wang, L. “Devils in Middle Layers of Large Vision-Language Models: Interpreting, Detecting and Mitigating Object Hallucinations.” CVPR 2025.