Skip to content

论文解读|Interleaved Head Attention:打破注意力头的「信息孤岛」

封面

论文信息

  • 标题:Interleaved Head Attention
  • 作者:Sai Surya Duvvuri, Chanakya Ekbote, Rachit Bansal, Rishabh Tiwari, Devvrit Khatri, David Brandfonbrener, Paul Liang, Inderjit Dhillon, Manzil Zaheer
  • 机构:Meta FAIR / UT Austin / UC Berkeley / Harvard / MIT
  • 发表时间:2026 年 2 月 | arXiv: 2602.21371

TL;DR

大模型回答"《霍比特人》的作者出生在哪?"这类问题时,需要先找到"作者是托尔金",再找到"托尔金出生于南非"——这是一个两步推理。但 Transformer 的核心组件——多头注意力(MHA)有一个根本缺陷:每个注意力头只能学一种关注模式。一个头要么学"找作者",要么学"找出生地",没法在一个头里完成两步。想处理 k 步推理?就得堆 k 个头,成本线性增长。

本文提出的 Interleaved Head Attention(IHA,交错头注意力) 解决了这个问题。核心做法是:在计算注意力之前,先把所有头的 Q、K、V 互相混合,生成一组"伪头"。由于伪 Query 和伪 Key 来自不同头的混搭,它们的交叉组合能在一个头内产生多种关注模式——P 个伪头可以组合出 P² 种模式,而不是原来的 1 种。举个例子:要同时处理 9 种不同步数的推理链,MHA 需要 9 个头各管一种;而 IHA 只需要 3 个头——因为 3 个伪 Query × 3 个伪 Key 交叉组合,恰好产生 9 种模式,一一覆盖。整个过程只加了很少的参数,而且完全兼容 FlashAttention,不需要改底层 CUDA 代码。

实测效果(2.4B 模型,FLOP 对齐对比):在 RULER 长上下文基准的 Multi-Key Retrieval(从长文本中同时找到多个散布的关键值)任务上,准确率提升 10-20%(4K-16K 上下文);在 GSM8K 小学数学推理上,多数投票准确率提升 5.8%


一、问题:MHA 的「头间隔离」瓶颈

Multi-Head Attention(MHA)是当代大模型的核心计算原语。但它有一个常被忽略的根本性限制:

H 个注意力头 → 恰好 H 个独立的注意力矩阵,头间零通信。

一个具体的例子

为了理解这个限制为什么重要,先看一个简单的多步推理问题:

"《霍比特人》的作者出生在哪里?"

上下文中散布着若干事实:

  • 事实 ①:"J.R.R. 托尔金写了《霍比特人》"
  • 事实 ②:"J.R.R. 托尔金出生于南非"

模型需要做两步:

  1. 第一步:从"《霍比特人》"关注到事实 ①,找到"托尔金"(这是一个直接关联)
  2. 第二步:从"托尔金"关注到事实 ②,找到"南非"(这是在第一步结果上的二次关联)

现在问题来了:在 MHA 中,每个注意力头只有一个 Query-Key 点积矩阵,它只能学习一种 attention pattern——要么学"作品→作者"的关系,要么学"人物→出生地"的关系,不能在一个头内同时做两步

那要同时完成两步怎么办?只能用两个不同的头,各负责一步。如果推理链更长(三步、四步...),就需要更多的头。

把直觉形式化:多项式图滤波器

论文用多项式图滤波器(Polynomial Graph Filters)精确刻画了这个问题。

把上面的例子抽象成一张图:每条事实是一个节点,两条事实如果共享实体(比如都提到"托尔金")就连一条边,邻接矩阵记为 A。那么:

  • AX(一跳聚合)= 每个节点收集直接邻居的信息 → 对应一步推理("霍比特人"找到"托尔金")
  • A²X(两跳聚合)= 信息再传播一轮 → 对应两步推理("霍比特人"经由"托尔金"找到"南非")
  • A^{k-1}X(k-1 跳聚合)→ 对应 k-1 步推理

要让模型在单层注意力中同时捕获 1 跳、2 跳、...、k-1 跳的信息(即同时处理不同长度的推理链),就需要并行计算 [X, AX, A²X, ..., A^{k-1}X]。

而 MHA 的瓶颈在于:一个注意力头只能实现一个矩阵幂 A^i——因为每个头只有一对 Q/K 投影,只能产生一个 attention 矩阵,对应一种固定的跳数。要表示从 A⁰ 到 A^{k-1} 共 k 种不同幂次,就必须用 k 个独立的头。

定理 1:用单层线性 MHA 表示 k 阶多项式滤波器,至少需要 k 个注意力头,参数复杂度为 Θ(kn²)

回到霍比特人的例子:如果上下文中同时存在需要 1 步、2 步、3 步才能回答的问题,MHA 至少需要 3 个头分别处理——参数需求随推理步数线性增长


二、IHA:让注意力头「交叉对话」

2.1 核心思想

MHA 的根源问题在于:每个头的 Query 只能与自己的 Key 交互,产生一种 attention pattern。那如果允许一个头的 Query 与其他头的 Key 也交互呢?

这正是 IHA 的做法:在注意力计算之前引入跨头混合(cross-head mixing),为每个头构建 P 个"伪头"(pseudo-heads),每个伪头的 Q/K/V 都是所有 H 个原始头的 Q/K/V 的学习线性组合。

回到霍比特人的例子:假设头 1 学到了"作品→作者"的 Q/K 投影,头 2 学到了"人物→出生地"的 Q/K 投影。在 MHA 中,它们各干各的。但在 IHA 中,一个伪头可以把头 1 的 Query 和头 2 的 Key 混合在一起,从而在单个伪头内实现两步推理的组合——"作品→作者→出生地"。P 个伪 Query × P 个伪 Key = P² 种组合,一个头就能覆盖多种推理链。

IHA 架构总览(论文 Figure 1)

图 1:IHA 架构。首先通过可学习的线性变换(×α_Q)在头轴(绿色)上为每个头生成 P 个伪 token,然后交错排列形成长度为 P·N 的扩展序列,最后在扩展序列上执行标准因果自注意力(使用滑动窗口控制计算复杂度)。

2.2 算法六步走

Step 1:标准投影 — 和 MHA 一样,将输入 X 通过 W_Q, W_K, W_V 投影得到 Q, K, V,形状 [N, H, d]。

Step 2:伪头混合(核心创新) — 为每个头 h 构建 P 个伪头。以 Query 为例:

Q~h,j=m=1Hαm,h,jQXWQ(m)

其中 α^Q ∈ R^{H×H×P} 是可学习的混合权重。K、V 同理。通常设 P = H

直觉上,每个伪 Query 不再只"看"自己头的 Key,而是看到了所有头的 Key 的混合版本。

Step 3:交错排列(Interleave) — 将伪头维度合并到序列维度,形成长度为 N·P 的扩展序列:

(token₁, pseudo₁), (token₁, pseudo₂), ..., (token₁, pseudo_P),
(token₂, pseudo₁), ..., (token_N, pseudo_P)

这种交错设计与 RoPE 天然兼容——每个伪头 token 获得独立的位置编码相位。

Step 4:标准注意力 — 在扩展后的序列上执行标准 scaled dot-product attention。

这是 IHA 的关键工程优势:完全兼容 FlashAttention——不像 Talking Heads 那样需要在 softmax 前后加额外的线性层。

Step 5-6:拆分与折叠 — 将输出 reshape 回来,通过可学习的折叠矩阵 R ∈ R^{H×P} 合并伪头结果,拼接所有头并通过输出投影。

2.3 为什么是「二次扩展」?

这是 IHA 最重要的数学洞察:

MHAIHA
每头注意力模式数1
H 个头总共HH·P²
表示 k 种模式需要的头数Θ(k)Θ(√k)

原因:P 个伪 Query 与 P 个伪 Key 之间形成 P×P 的交叉交互矩阵,每个组合 (伪Q_i, 伪K_j) 对应一种独立的注意力模式。

打个比方:MHA 就像你有 H 把钥匙和 H 把锁,每把钥匙只能开自己的锁,H 把钥匙开 H 把锁。IHA 则是把这些钥匙和锁"混搭"出 P 把新钥匙和 P 把新锁——任意一把新钥匙都可以尝试任意一把新锁,P×P = P² 种开锁方式。所以表达 k 种模式,MHA 需要 k 个头,IHA 只需要 √k 个。


三、理论保证

3.1 严格泛化 MHA

定理 2(IHA 超集性质):对于任意 P ≥ 2,MHA 可表示的函数类是 IHA 的严格子集

M ⊊ P_P

额外参数开销仅为 4H²P(3H²P 来自 α^Q, α^K, α^V 混合权重 + H²P 来自折叠矩阵 R),相对于模型总参数量微不足道。

为什么是严格超集? 论文用了一个精妙的构造性证明:在"所有 token 相同"的特殊输入上,MHA 的每个头的 attention score 矩阵各行完全相同(因为所有 Q、K 都一样),softmax 后变成均匀分布,因此输出必然是输入的线性函数。但 IHA 可以通过设置正负伪头(α = +1 和 α = -1),使 softmax 矩阵出现非均匀的结构,从而产生输入的非线性函数

3.2 多项式滤波器效率

表示 k 阶多项式图滤波器所需资源:

参数量头数
MHAΘ(k·n²)k
IHAΘ(√k·n²)⌈√k⌉

参数效率提升了 √k 倍

直觉:假设需要实现 9 种推理步数(k=9),MHA 需要 9 个头,每个头专门负责一种(A⁰, A¹, ..., A⁸)。而 IHA 只需要 3 个头(⌈√9⌉ = 3),因为它能把每个步数 t 分解为两个因子的乘积

步数 t分解为 (h, j)伪Q 用 A^伪K 用 A^乘积 = A^t
0(1,1)A⁰A⁰A⁰
1(1,2)A⁰
2(1,3)A⁰
3(2,1)A⁰
...............
8(3,3)A⁶A⁸

3 个 Query 基底 × 3 个 Key 基底 = 9 种组合,恰好覆盖所有 9 种推理步数。这就是"二次扩展"在多项式滤波器上的具体体现:用 √k 个基底的交叉组合替代 k 个独立的头。

3.3 CPM-3 任务(有序计数匹配)

对于需要处理所有有序 token 对的 CPM-3 任务:

所需头数注意力计算量
MHAN_maxΘ(N³_max)
IHA⌈√N_max⌉Θ(N^{2.5}_max)

头数从 N_max 降到 √N_max,背后的原理与多项式滤波器相同——分解乘法结构


四、实验结果

所有实验基于 2.4B 参数的 decoder-only Transformer,在 128 张 H200 GPU 上训练 240B tokens,严格 FLOP 匹配对比。

4.1 长上下文检索

在 64K 上下文窗口下微调后,用 RULER 基准评测:

RULER 长上下文结果(论文 Figure 2)

图 2:(a) Multi-Key Retrieval 准确率,橙色标注 IHA 相对提升;(b) RULER 整体精确匹配得分。

IHA 在 Multi-Key Retrieval 任务上的提升非常显著:

上下文长度IHA 相对 Global Attention 提升
4K+27%
8K+32%
16K+112%

RULER 整体精确匹配得分:IHA 44.0% vs Global+Local 40.6% vs Diff Transformer 37.2% vs Global Attention 35.0%。

这个结果非常有说服力——上下文越长,IHA 的优势越大。Multi-Key Retrieval 需要从上下文中同时定位多个散布的关键信息并聚合,恰好是 IHA 的跨头混合最擅长的场景。

4.2 推理评估

预训练模型(5-shot)

模型GSM8K EMGSM8K Maj@5MATH-500 EMMBPP P@1HumanEval P@1平均排名↓
IHA8.34% (+2.73)8.42% (+2.81)3.54% (+0.66)24.5% (+1.1)17.1% (–0.1)1.4
Global Attention5.61%5.61%2.88%23.4%17.2%2.9
Global+Local6.82% (+1.21)6.90% (+1.29)2.26% (–0.62)23.6% (+0.2)16.0% (–1.2)2.9
Talking Heads5.46% (–0.15)5.38% (–0.23)23.8% (+0.4)16.0% (–1.2)4.0
Diff Transformer5.46% (–0.15)5.61% (±0)25.0% (+1.6)15.4% (–1.8)3.5

表 1:预训练模型 5-shot 评测。IHA 在推理指标上一致领先,平均排名最优。

IHA 在数学推理任务上已经展现了优势——GSM8K 提升近 50%,而这还只是预训练阶段。

SFT 后(OpenThoughts 微调)

模型GSM8K P@1GSM8K Maj@16MATH-500 P@1MATH-500 Maj@16MBPP P@1MBPP P@10平均排名↓
IHA34.3% (+4.8)54.2% (+5.8)10.0% (+1.2)18.4% (+2.8)15.5% (+0.8)41.6% (+0.4)1.5
Global Attention29.5%48.4%8.8%15.6%14.7%41.2%3.8
Global+Local26.5% (–3.0)46.9% (–1.5)7.6% (–1.2)15.0% (–0.6)15.0% (+0.3)41.9% (+0.7)4.3
Talking Heads29.3% (–0.2)49.4% (+1.0)7.8% (–1.0)18.2% (+2.6)15.9% (+1.2)43.1% (+1.9)2.5
Diff Transformer31.6% (+2.1)53.5% (+5.1)9.0% (+0.2)18.0% (+2.4)15.3% (+0.6)39.2% (–2.0)2.8

表 2:OpenThoughts SFT 后评测。IHA 在所有推理指标上排名第一。

SFT 后 IHA 的优势进一步放大:GSM8K Maj@16 达到 54.2%(+5.8%),MATH-500 Maj@16 达到 18.4%(+2.8%)。

4.3 合成推理任务

前面的理论分析预测了:IHA 在需要组合多步关系的任务上应该表现更好。论文用 Binary / Ternary Relation Composition 合成任务直接验证了这一点。

这些任务要求模型给定两个(或三个)关系表,把它们组合起来——本质上就是"A→B"+"B→C"="A→C" 的多步关系链接,与霍比特人例子中的推理结构完全一致,只是规模更大、更受控。

Binary Relation Composition

图 3:二元关系组合(两步推理)。IHA(绿色)在相近参数量下(208K vs MHA 的 207K)大幅领先,测试准确率 77.7% vs MHA 71.5%(η=10⁻³),提升 6.2 个百分点。

Ternary Relation Composition

图 4:三元关系组合(三步推理)。推理步数增加后 IHA 优势更加明显:83.1% vs MHA 78.1%(η=10⁻³),提升约 5 个百分点。这符合理论预测——推理链越长(k 越大),IHA 的 √k 优势越显著。


五、计算开销与工程设计

全局 IHA 的计算复杂度为 O(P²·N²·d),是标准注意力的 P² 倍——这看起来代价很大,但论文通过混合调度策略解决了这个问题:

4 层滑动窗口 IHA(窗口大小 W = N/(2P²))+ 1 层全局 IHA,交替执行。

这使得平均计算量与全局注意力基线严格匹配(FLOP-matched),确保实验对比的公平性。

与其他方法的对比

方法混合层级FlashAttention 兼容交互模式数
Talking Heads注意力 logits/weightsH
Knocking Heads注意力权重H
Diff Transformer注意力 map 差分2H
IHAQ/K/V 投影层(注意力前)H·P²

IHA 的独特优势:在注意力算子之前做混合,保留了标准 softmax(QK^T)V 的形式,因此可以直接调用 FlashAttention 等高效 kernel,无需任何自定义 CUDA 代码。


六、关键洞察与讨论

一句话串起来

再回到开头的问题——"《霍比特人》的作者出生在哪里?":

  • MHA 的困境:头 1 学会了"作品→作者"(A¹),头 2 学会了"人物→出生地"(A¹),但要在单层内完成"作品→作者→出生地"(A²),还需要第三个头专门学习两跳组合。推理步数每多一步,就需要多一个头。
  • IHA 的解法:不需要第三个头。IHA 构造一个伪 Query(混合了头 1 的 Query 投影)和一个伪 Key(混合了头 2 的 Key 投影),它们交互后自然实现了 A¹·A¹ = A²——两步推理在一个伪头的一次 attention 中完成。2 个头的 2×2 = 4 种伪头组合就能覆盖 A⁰ 到 A³ 共 4 种推理步数。
  • 实验验证:这个理论预测在实验中得到了验证——IHA 在需要从长上下文中聚合多条分散线索的 Multi-Key Retrieval 任务上提升了 27%-112%,在需要多步数学推理的 GSM8K 上提升了 5.8%。

从线性到二次:乘法的力量

IHA 的效率提升源自一个简单但深刻的数学原理:乘法比加法高效

MHA 的 H 个头各自独立工作 → H 种模式(加法式扩展)。IHA 的 P 个伪 Query × P 个伪 Key → P² 种模式(乘法式扩展)。这就像十进制用两位数能表示 100 个数(10×10),而不是只有 20 个(10+10)。

局限性:速度到底慢不慢?论文没说

这是本文最值得注意的缺失:论文没有报告任何实际的时间开销数据——没有 wall-clock time、没有 tokens/second、没有训练/推理延迟对比。

论文只做了理论 FLOP 匹配:通过滑动窗口调度,把 IHA 的平均理论计算量拉到和标准注意力大致相当。但 FLOP 匹配不等于实际速度相同,至少有三个被回避的问题:

  1. 伪头混合本身的开销:虽然额外参数只有 4H²P,但每一层都要做一次跨头的 einsum 操作,这部分的实际耗时没有量化
  2. 内存访问模式变化:序列从 N 扩展到 N·P 后,FlashAttention 的分块效率、KV cache 的大小都会受影响
  3. 滑动窗口的实现效率:理论上 FLOP 对齐了,但滑动窗口在不同硬件上的实际吞吐量差异很大

对于想在工程中落地的团队来说,这些缺失的数据可能比准确率提升更关键。

其他局限性

  • 当前仅验证了 2.4B 规模,更大模型上的表现有待验证
  • 尚未扩展到非自回归架构(encoder-decoder 和 vision 场景)

七、总结

IHA 的贡献可以用一句话概括:

通过在注意力计算前引入跨头线性混合,将注意力模式从 H 种提升到 H·P²种,以 O(H²P) 的微小参数代价换取了二次级别的表达能力提升。

这篇论文的优雅之处在于:它没有发明新的注意力算子,没有引入复杂的架构,只是在 Q/K/V 投影和标准注意力之间插入了一步轻量级的线性混合——就这么简单的一步,就打破了 MHA 自 2017 年提出以来的"头间隔离"范式。

对于关注 Transformer 架构演进的研究者和工程师来说,IHA 提供了一个清晰的信号:注意力头的独立性不是不可触碰的原则,适度的跨头通信可以带来实质性的能力提升