《Attention Residuals》:让残差连接也注意力化

2026-03-19 · 5182 字 · 25 分钟

Kimi 团队 Attention Residuals 技术报告:为什么残差连接也该“注意力化”,以及 Full AttnRes / Block AttnRes 如何把这个想法做成可训练、可部署的系统

残差连接长期被当成训练稳定性的管道:让梯度更容易穿过深层网络,让旧表示不要丢得太快。但如果把模型看成信息系统,残差还有更深的问题:深度方向的信息到底该怎么路由?

《Attention Residuals》 的核心判断,是把残差连接从稳定训练的管道,重新定义成跨层信息路由。它问的不是“再加一个模块能涨几分”,而是为什么序列维度已经有注意力,深度维度还在用固定加法。

0. 先认几个词

如果你完全没有机器学习背景,可以顺着这篇报告真正关心的问题,按下面这个顺序先建立一个直觉:

  • Transformer:今天大多数大模型的基础架构。你可以先把它理解成一台一层一层处理信息的机器。
  • 隐状态(hidden state):模型在某一层里的内部中间表示。可以粗略理解成“模型此刻脑子里的临时笔记”。
  • 残差连接(residual connection):层和层之间的一条“保留旧内容”的通道。它会先把上一层的内容留住,再把这一层新算出来的东西加上去。
  • 残差项(residual):更接近“这一层新补上去的增量”,也就是上面那条残差连接里新增的那一部分。
  • 注意力(attention):从很多信息里,挑出“当前最该看哪一部分”的机制。这个词你可以先记成“有选择地看重点”。
  • PreNorm:在进入一层之前,先把数值尺度调匀,再做后续计算。可以把它想成“先把音量调到合适,再继续混音”。

1. 一句话说清楚

这份技术报告提出了一个问题:

既然 Transformer 已经用注意力机制取代了“时间维度上的递归”,为什么大模型在“深度维度上的信息聚合”还停留在固定加法?

现代大语言模型(LLM)几乎都在用一种很常见的层结构:先做 PreNorm,再走残差连接。直白地说,就是先把数值尺度调匀,再把这一层新算出来的结果加回原输入。大家熟悉它的一个功能,是让训练过程更稳定,深层网络不那么容易失控。但作者提醒我们,残差连接其实还有另一个同样重要、却长期被忽视的角色:

它定义了信息怎样沿着深度被汇总。

如果下面的式子看不熟,不用卡住,直接看后面的“翻译成人话”就够了。

标准残差的规则很简单:

hl=hl1+fl1(hl1)h_l = h_{l-1} + f_{l-1}(h_{l-1})

这里可以直接把两部分拆开看:

  • hl1h_{l-1}:旧内容,也就是上一层已经有的表示
  • fl1(hl1)f_{l-1}(h_{l-1}):这一层新算出来的增量,更接近“残差”这个词本身

而把这两部分重新加在一起的整条做法,才更准确地叫“残差连接”。

把这个递推式展开,你会得到:

hl=h1+i=1l1fi(hi)h_l = h_1 + \sum_{i=1}^{l-1} f_i(h_i)

翻译成人话就是:第 ll 层看到的输入,本质上是“嵌入表示(embedding)加上前面所有层输出的统一加总”。每一层的权重都是 1,没有选择,没有抑制,没有“这一步我更该看第 3 层还是第 17 层”的机制。

AttnRes 的核心思想只有一句话:

把残差连接从固定加法,改成沿深度做一次 softmax 注意力。

2. 旧残差到底哪里有问题?

这份技术报告最重要的地方,不在于它提出了一个新公式,而在于它把一个大家已经习惯了的东西重新问题化了。

标准残差长期被视为“训练稳定性工具”。只要能让梯度过得去,它就算完成任务了。但从信息流角度看,这条路径其实非常粗糙。

想象你在写一份持续迭代的文档。每一轮修改,你都不是“挑出最相关的旧版本内容再整合”,而是把之前所有版本一股脑全文追加到文档末尾。第 20 轮的时候,前 3 轮的重要洞察当然还在,但它们已经淹没在越来越厚的堆叠里了。

PreNorm 的问题就在这儿。报告引用了 SiameseNorm 的观察,并进一步强调:在 PreNorm 下,hidden state 的量级会随着深度近似按 O(L)O(L) 增长。这里的隐状态,说白了就是模型每一层里的那份“内部笔记”。结果就是:

  • 越往后的层,看到的是一个越来越膨胀的“历史总和”
  • 早期层的信息虽然没有消失,但会被不断稀释
  • 后面层如果还想“发出声音”,就被迫输出更大的量级

这篇技术报告把这个现象叫 PreNorm dilution。这是一个非常准确的命名。不是梯度断了,不是模型炸了,而是每一层的相对贡献被越来越稀。

报告里有一条值得抓住的潜台词:我们在序列维度上早就不满足于“所有过去词元(token)一视同仁”了,所以才有了注意力机制;那为什么到了深度维度,却还能接受“所有过去层统一权重相加”?

3. AttnRes 到底做了什么

AttnRes 的形式很干净。第 ll 层不再机械地接收“前面所有层输出的总和”,而是对这些历史表示做一次加权选择:

hl=i=0l1αilvih_l = \sum_{i=0}^{l-1} \alpha_{i \to l} \cdot v_i

其中权重 αil\alpha_{i \to l} 来自一层 softmax。你可以先把 softmax 理解成“把一组分数压成一组权重,而且所有权重加起来等于 1”,这样模型才能明确表达“更该看谁、少看谁”:

αil=softmax(wlTRMSNorm(ki))\alpha_{i \to l} = \operatorname{softmax}\left(w_l^T \operatorname{RMSNorm}(k_i)\right)

如果你没接触过注意力机制,还有一个最省力的理解方式:

  • 查询(query):当前这一层现在想找什么
  • 键(key):每一层历史信息各自贴着什么“索引标签”
  • 值(value):最后真正被取回来、参与汇总的内容

这里有三个关键设计。

第一,查询不是当前隐状态现算出来的,而是每层一个可学习的伪查询向量 wlw_l
这有点反直觉。我们平时看到注意力机制,会自然以为查询必须来自当前输入。但作者故意把查询设计成层级参数,而不是按词元动态生成的向量。这样做的好处是:同一个块里的多个查询可以提前批量算,后面基础设施优化才有空间做。

第二,键和值直接来自前面层的输出。
也就是说,真正带来“输入相关性”的不是查询,而是各层当前样本上的表示本身。不同样本经过前面层后得到的键不一样,所以最后的深度注意力依然是输入相关的。

第三,键前面加了 RMSNorm。
这是个很关键的小设计。因为如果不做归一化,量级大的层会天然在点积里占便宜,你得到的就不是“谁更相关”,而更像“谁声音更大”。报告正文也明确强调了这一点。

Python
import torch
from torch import nn
def attention_residual(
sources: list[torch.Tensor],
pseudo_query: torch.Tensor,
norm: nn.RMSNorm,
) -> torch.Tensor:
keys = torch.stack([norm(source) for source in sources], dim=0)
values = torch.stack(sources, dim=0)
logits = keys @ pseudo_query
weights = torch.softmax(logits, dim=0)
return (weights.unsqueeze(-1) * values).sum(dim=0)

这个式子看上去像是“把注意力机制用在残差连接上”。但更准确的说法是:

它把残差连接从“固定的累加器”改成了“可选择的深度检索器”。

4. 这份报告给了想法,也给了工程

一句话结论:这篇报告提出了 Full AttnRes,它把这个想法推进成了一套可训练、可部署、算得清账的工程方案。

Full AttnRes 让每一层都看到前面所有层,理论上很好理解,实际上也不算太贵。因为网络深度 LL 通常远小于序列长度 TT,所以作者说,单纯算术量 O(L2d)O(L^2 d) 并不是最可怕的问题。

真正的问题出现在大训练里:

  • 激活重计算(activation recomputation)会把本来可以丢掉的中间层输出重新变成必须保存的对象
  • 流水线并行(pipeline parallelism)会让这些跨层表示需要跨阶段传输
  • 一旦每层都要看所有前层,通信和缓存压力会快速上去

所以他们又提出了 Block AttnRes

做法是把 LL 层切成 NN 个块。块内部先用普通求和攒成一个块级表示,跨块再做注意力。这样一来:

  • Full AttnRes:看的是所有历史层
  • Block AttnRes:看的是所有历史块的摘要,再加当前块的部分和

本质上是用“摘要级跨层注意力”换取可扩展性。

作者没有只停在“分块所以省内存”这个层面,而是把系统层的账也算清楚了:

  • 训练阶段用 跨阶段缓存(cross-stage caching),避免流水线里重复传历史块
  • 推理阶段用 两阶段计算(two-phase computation)
  • 第一阶段并行算块间注意力(inter-block attention)
  • 第二阶段顺序算块内回看(intra-block lookback),再用在线 softmax 合并

从附录和 table/memory_access.tex 里能看到最硬核的一组数字。按报告给的典型设定:

  • 标准残差连接:每层残差机制 I/O 是 3d
  • naive Full AttnRes:130d
  • 优化后的 Full AttnRes:24d
  • Block AttnRes:5.5d
  • mHC:34d

这组数字特别说明问题。Block AttnRes 不是“便宜到跟标准残差连接一样”,但它已经从“明显不现实”降到了“工程上值得试”。而且报告实测给出的代价也不大:

  • 训练端实际耗时开销小于 4%
  • 推理端时延开销小于 2%

这也是它像一篇系统级技术报告的原因。很多论文的问题在于“想法是新的,账是糊的”;这篇在账本上反而做得很用力。

5. 实验最该看什么

在主实验中,AttnRes 在缩放趋势、训练动力学和下游能力上都给出了方向一致的信号。

5.1 缩放定律:不是偶然赢一把

作者先做了五个模型规模的缩放定律实验,对比 Baseline、Full AttnRes 和 Block AttnRes。

拟合出来的曲线是:

  • Baseline:1.891×C0.0571.891 \times C^{-0.057}
  • Block AttnRes:1.870×C0.0581.870 \times C^{-0.058}
  • Full AttnRes:1.865×C0.0571.865 \times C^{-0.057}

这三条曲线最重要的信息不是“斜率差了多少”,而是:

AttnRes 在整个算力区间里都持续更低。

报告给了一个很容易传播的结论:在 5.6 PFLOP/s-days 这个预算点,Block AttnRes 的损失相当于 baseline 多花 1.25x 算力才能达到的水平。

换句话说,这不是“在某个模型大小上碰巧调对了”,而是有比较稳定的规模收益。

5.2 大模型主实验:不是玩具规模

主实验不是小模型上的玩具规模基准实验,而是基于 Kimi Linear 的一个大配置:

  • 48B 总参数 / 3B 激活参数
  • 27 个 Transformer 块,也就是 54 层
  • 8-of-256 路由专家 + 1 个共享专家
  • 预训练 1.4T tokens

这说明作者不是只在“小模型上做漂亮曲线”,而是真把这个残差改造塞进了一个大训练配方里。

5.3 训练动态中,输出量级不再失控

训练动态那张图片,Baseline 的输出量级会随着深度一路涨上去。训练动态图里给的数值非常夸张:从前面几个块的 0.040.060.10,一直涨到后面几个块的 10.4712.15。这就是 PreNorm dilution 的视觉化版本。

Block AttnRes 则完全不是这条曲线。它在块边界形成一种周期性重置,量级大致在 0.211.91 之间波动,没有出现一路失控上扬。

这非常重要,因为它说明 AttnRes 不是只在最后 benchmark 上“多拿了几分”,而是真正在训练动力学层面改变了表示如何沿深度堆积。

5.4 下游任务:提升最明显的是推理和代码

预训练后,AttnRes 在报告列出的全部评测上都不差于 baseline,几个最亮眼的点包括:

  • MMLU:73.5 -> 74.6
  • GPQA-Diamond:36.9 -> 44.4
  • Math:53.5 -> 57.1
  • HumanEval:59.1 -> 62.2
  • C-Eval:79.6 -> 82.5

GPQA、Math、HumanEval 这种多步推理或程序生成任务涨幅更大,这一点直接对应了报告的机制假设。报告作者的解释是:如果后层能更有选择地回收前层表示,那么需要组合式推理的任务会更受益。这个解释是说得通的。

因为复杂推理最怕的不是“信息不存在”,而是“信息在网络很深的地方被埋住了”。

6. 消融实验告诉了我们什么

消融实验的关键结论,不是“连得更密就更强”,而是“沿深度做输入相关的选择性聚合”这件事本身在起作用。

这份报告的消融做得不错,因为它不只是证明“有用”,还试图证明“为什么有用”。

几个关键结论:

  • DenseFormer 1.767,几乎和 baseline 1.766 一样。
    这说明“能访问所有前层”本身还不够,关键在于权重是不是输入相关的。

  • mHC 到了 1.747,已经明显变好。
    这说明深度维度上的动态混合确实有效。

  • Full AttnRes 到了 1.737。
    它比 baseline、DenseFormer、mHC 都更低,说明显式的沿深度 softmax 注意力是一条更强的路线。

  • SWA(只看最近窗口)只有 1.764。
    这很有价值。它说明 AttnRes 的收益不只是“多看最近几层”,而是“能选择性地看更远的层”。

  • 块大小从 2、4、8 变化时,损失都在 1.746 左右。
    这就是为什么作者最后固定大约 8 个块。不是拍脑袋,而是工程和效果之间一个相当好的平衡点。

  • 输入相关查询版本做到 1.731,比 Full AttnRes 还好。
    这个结果说明当前报告里的伪查询设计并不是性能上限,而是一个为基础设施优化让路的折中。也就是说,作者不是不知道更强的写法,而是主动选了更容易扩展的写法。

这正是这份报告的工程取舍。你从正文、消融和系统设计里能更清楚地看到他们的真实取舍:不是盲目追求最低 loss,而是在追求“足够强,同时真能训起来”。

7. 怎么看这份报告

第一,这份报告最重要的,不是它发明了一个新模块,而是它把残差连接从“训练稳定性工具”重新提升成了“信息路由机制”。

这个视角一旦建立起来,很多东西都会被重新理解。残差不再只是梯度高速通道,它还是深度聚合规则。你会开始追问:

  • 每一层到底能不能选择性地访问前层?
  • 深度维度上有没有“注意力汇聚陷阱”(attention sink)?
  • 旧的残差变体本质上是不是沿深度维度的线性注意力?

而这正是报告讨论部分有价值的地方。作者把一堆残差变体统一进了一个 depth mixing matrix 的视角里,进一步指出:

很多已有方法,本质上都像是在深度维度上做线性注意力;AttnRes 做的是沿深度维度的 softmax 注意力。

这个说法很大胆,但也很有启发性。它等于是在说:Transformer 当年把序列维度从递归推进到了 softmax 注意力;AttnRes 试图把深度维度也推进一步。

第二,这篇技术报告的气质很像“先把问题提对,再把系统做顺”。它没有执着于把每个部件都做到最花哨。比如查询故意做成按层设定的参数,而不是按词元动态生成的向量,性能上未必绝对最强,但它给了批量计算、两阶段计算、流水线缓存一个成立的基础。很多时候,一篇能落地的技术报告,靠的不是最激进的局部设计,而是整体约束下的取舍。

第三,这份报告中的这句话:

Why is depth-wise aggregation still fixed while everything else has become adaptive?(为什么沿深度的聚合仍然是固定的,而其他部分都已经变得自适应了?)

这句话抓住了问题的中心。

8. 这份报告的边界

第一,它目前是 技术报告 / arXiv 预印本,不是已经过同行评审的会议论文。写这类文章时,最稳妥的态度不是“它已经证明了未来”,而是“它提出了一个很强的视角,并给出了一套有工程可行性的实现”。

第二,它的大规模结果主要建立在 Kimi Linear 这条架构线上:MoE、KDA/MLA 混合注意力、Moonlight / DeepSeek-V3 风格训练配方。虽然这不削弱结果本身,但也意味着我们还不能自动把结论外推到所有纯稠密的仅解码器 Transformer。

第三,报告自己也承认:Full AttnRes 其实更强,Block AttnRes 是今天硬件约束下的工程解。未来如果显存、带宽、互连再往前走,或者更高效的深度注意力变体出现,今天这版 Block 设计很可能不是终点。

9. 这篇报告改变了什么问题

Attention Residuals 的钉子句是:残差连接不只是稳定训练的管道,也是跨层信息路由规则。

这份报告把一个被默认接受的结构重新问题化了。标准 PreNorm 残差把历史层近似等权累加,训练很稳,但信息路由很粗。AttnRes 追问:既然序列维度已经用 attention 做选择,深度维度为什么还只能固定相加?

这个角度比“又多一个模块”更有价值。它把 residual connection 从梯度高速路拉回到信息系统里:每一层到底该访问哪些前层,哪些表示该被保留,哪些该被压低。Block AttnRes 的意义,也不只是省显存,而是在效果、带宽、缓存和流水线之间给出一个能训练的大规模折中。

下次看架构创新,不要只看 benchmark 涨了几分。先问它重新定义了哪条信息通路。真正重要的结构改动,往往不是添加能力,而是改变能力流动的路径。


延伸阅读

全文完 · 谢谢阅读

评论