4. Attention

介绍 Attention 机制的基本原理、Scaled Dot-Product Attention 及其常见类型

什么是 Attention

  • Attention 的本质是 “根据相关性分配权重”。模型在处理某个 token 时,会考虑“当前 token 应该关注其它哪些 token?关注多少?”

  • 信息检索类比

    • Query(Q)=“我想找什么信息”

    • Key(K)=“别人能提供什么信息”

    • Value(V)=“别人具体能给的内容”

    • Attention 就像:你(Query)→ 浏览所有人的 Key → 判断谁更相关 → 根据权重从每个人的 Value 中取信息并加权融合

  • 社交网络传播类比

    • 你(当前词)发布一个问题(Query),所有其他词用 Key 回应 “我能提供什么”,然后你按相似度给每个词评分(注意力权重),最后根据评分汇总各词提供的 Value(上下文信息)

  • 因此 Attention 最终输出一个融合上下文语义的向量

Scaled Dot-Product Attention

  • Scaled Dot-Product Attention(点乘缩放注意力) 是 Transformer 中最常用的注意力机制,由 Vaswani 等人在 2017 年提出

scaled dot-product attention

  • 其核心思想是通过计算 Query 和 Key 之间的点积来衡量相关性,并根据相关性对 Value 进行加权求和

  • 设输入为矩阵 $Q \in \mathbb{R}^{n \times d_k}$、$K \in \mathbb{R}^{n \times d_k}$、$V \in \mathbb{R}^{n \times d_v}$​​,其求解过程如下:

计算相似度

  • 第一步,通过 Query 和 Key 的点积来计算相似度

    S=QKTS = QK^T
  • 其中 $S_{ij}$ 表示第 $i$ 个 Query 与第 $j$ 个 Key 的相关性

Scaled(缩放)

  • 对计算得到的相似度进行缩放

    S~=Sdk\tilde{S} = \frac{S}{\sqrt{d_k}}
  • 这是为了避免 $S$ 的数值随维度 $d_k$ 增大而过大,导致 Softmax 梯度消失 / 梯度爆炸

  • 点积方差随维度增加而变大

    • 设 $Q$ 和 $K$ 的每个元素独立且服从均值为 0,方差为 1 的分布,$d_k$ 是向量维度,那么 $Q$ 与 $K$ 的点积:

      QK=i=1dkQiKiQ \cdot K = \sum_{i=1}^{d_k} Q_i K_i
    • 根据方差公式:$\text{Var}(Q\cdot K) = \sum_{i=1}^{d_k} \text{Var}(Q_i K_i)$

    • 假设 $\text{Var}(Q_i K_i) = 1$,则 $\text{Var}(Q\cdot K) = d_k$

    • 即,随着 $d_k$ 增大,点积的方差也会增大

  • softmax 函数对大值敏感

    si=softmax(xi)=exijexjsixi=si(1si)sixj=sisj(ij)s_i=\text{softmax}(x_i) = \frac{e^{x_i}}{\displaystyle\sum_j e^{x_j}}\\ \frac{\partial s_i}{\partial x_i} = s_i(1-s_i)\\ \frac{\partial s_i}{\partial x_j} = -s_i s_j \quad (i\neq j)
    • $x_i$ 太大

      • 指数函数膨胀 → $s_i \approx 1$、$s_j \approx 0$ → softmax 的输出接近 one-hot

      • 此时 $s_i(1-s_i) \approx 0$、$s_i s_j \approx 0$ ,即梯度都接近于 0

      • 这就叫饱和区(saturation region)

    • $x_i$ 太小 → 近似均匀 → 注意力不够尖锐

    • 当 $d_k$ 很大时,$QK^T$ 数值很大,点积的分布会随着 $d_k$ 变得越来越“宽”

    • 从而使得 softmax 输出过于尖锐 → 直接进入梯度几乎为 0 的饱和状态 → 反向传播时 $\nabla_q,\nabla_k$ 极小 → 梯度消失 → 难以训练

  • 缩放的作用:

    QKTdk\frac{QK^T}{\sqrt{d_k}}
    • 方差从 $d_k$ 缩小到 1

    • Softmax 输出保持合理分布,不至于饱和

    • 训练更加稳定,梯度不会消失或爆炸

  • 公式示意

    • 假设 $d_k = 64$,Q / K 元素服从 $\mathcal{N}(0,1)$

    • 不缩放:$\text{Var}(QK^T) = 64$ → Softmax 输出接近 one-hot

    • 缩放 $\frac{1}{\sqrt{64}} = \frac{1}{8}$ → $\text{Var} \approx 1$ → Softmax 平滑

Mask(可选)

  • Decoder 中的 causal mask

    • 保证不能看到未来

    • 此时称为 Causal-Attention,又称 Masked Self-Attention

    • 可以对上三角的注意力分数替换为 $-\infty$

  • Padding mask(忽略 PAD token)

    S~=S~+M\tilde{S}' = \tilde{S} + M
    • 在掩码矩阵 M 中,对需要屏蔽的无效位置填充 $-\infty$

    • 在对 padding 位置进行处理时,将其置为 $-\infty$ 而不是 0,是为了在后续计算 Softmax 时能够正确地屏蔽这些位置

    • 若对应位置为 0,Softmax 会产生非零权重($e^0 = 1$),从而引入错误的注意力贡献

    • 而当该位置为 $-\infty$ 时,Softmax 输出为 0($e^{-\infty} = 0$),可以将 padding 位置完全忽略

Softmax → 权重分布

  • Q 与 K 的缩放点积计算得到一个原始注意力分数矩阵,刻画了 Query 与 Key 向量之间的匹配程度,但由于尚未进行归一化,其数值范围不受约束,因而不能直接作为注意力权重来使用

  • 为了将这些原始分数转化为可解释、可用的注意力分布,对注意力矩阵的每一行独立地应用 softmax 函数,从而得到一组非负且和为 1 的权重

  • 对于特定的查询 $q_i$(对应 $Q$ 的第 $i$ 行),它与键 $k_j$(对应 $K^T$ 的第 $j$ 列)对齐的原始注意力分数记作

    sij=qikjTdks_{ij} = \frac{q_i k_j^T}{\sqrt{d_k}}
  • softmax函数将查询 $q_i$ 与所有 $N$ 个键($N$代表键/值对的序列长度)的注意力分数向量 $s_i = [s_{i1}, s_{i2}, ..., s_{iN}]$ 转换为注意力权重向量 $\alpha_i = [\alpha_{i1}, \alpha_{i2}, ..., \alpha_{iN}]$,其中每个权重 $\alpha_{ij}$ 的计算方式如下:

    αij=softmax(sij)=exp(sij)l=1Nexp(sil)\alpha_{ij} = \text{softmax}(s_{ij}) = \frac{\exp(s_{ij})}{\sum_{l=1}^N \exp(s_{il})}
  • 总而言之,通过 Softmax 函数,将原始注意力分数转为权重分布

    A=softmax(S~)A = \text{softmax}(\tilde{S}')
  • softmax 的作用

    • 归一化

    • 非负性:由于指数函数 $\exp(x)$ 的输出恒为正,因此每个注意力权重 $\alpha_{ij}$ 都保证是正数

    • 概率解释:针对查询 $q_i$ 的权重集合 ${\alpha_{i1},\alpha_{i2},\dots,\alpha_{iN}}$,代表了模型在计算第 $i$ 个位置的输出时,对输入序列中各个元素的注意力概率分布,$\alpha_{ij}$ 表示分配给第 $j$ 个输入元素的注意力比例

    • 突出重要性:指数函数会放大较大的分数,缩小较小的分数,如果某个分数 $s_{ik}$ 在该行中显著大于其他分数,对应的权重 $\alpha_{ik}$ 会趋近于1,而其他权重则趋近于0,这让模型能有效聚焦于最相关的输入元素

加权求和 Value

  • 最终公式为

    Attention(Q,K,V)=AV=Softmax(QKTdk+M)V\text{Attention}(Q, K, V) = A V = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}}+M) V
  • 最终输出形状与 $V$ 相同

  • 像注意力机制这种基于矩阵运算的关键优势在于高度的并行性,其整个计算过程主要由矩阵乘法和逐行 softmax 组成,能够在为此类操作优化的硬件上高效执行

  • 与需要按时间步 $t=1,2,\cdots,n$ 顺序处理的 RNN 和 LSTM 不同,注意力机制可以近似并行地计算所有位置对 $(i,j)$ 之间的交互,从而消除了顺序计算瓶颈,使得训练速度更快,并在条件允许的情况下支持更长序列的建模

Bahdanau Attention

  • Bahdanau Attention,又称 Additive Attention,是 2015 年 Bahdanau 等人在神经机器翻译中提出的注意力机制

  • 它主要解决了 Seq2Seq 模型在长序列下Encoder 隐藏状态信息压缩成单个上下文向量的问题

  • 核心思想:

    • 不再将 Encoder 隐藏状态全部压缩成单个向量,而是将 Encoder 的所有隐藏状态输入给 Decoder

    • Decoder 在生成每个输出时,可以 “选择性关注” Encoder 的不同时间步

    • 注意力权重通过可学习的对齐模型(alignment model)计算

  • 给定 Encoder 隐藏状态序列:$H = [h_1, h_2, \dots, h_T],\quad h_i \in \mathbb{R}^{d_h}$,Decoder 当前隐藏状态:$s_t \in \mathbb{R}^{d_s}$

  • 计算对齐分数(score function)

    • Bahdanau 采用加性函数:

      eti=vtanh(Whhi+Wsst+b)e_{ti} = \mathbf{v}^\top \tanh(W_h h_i + W_s s_t + b)
    • 其中,$W_h \in \mathbb{R}^{d_a \times d_h}$,$W_s \in \mathbb{R}^{d_a \times d_s}$,$b \in \mathbb{R}^{d_a}$,$\mathbf{v} \in \mathbb{R}^{d_a}$

    • $d_a$ 是对齐向量的维度,通常小于 $d_h$ 或 $d_s$

  • 归一化注意力权重,使用 softmax 将对齐分数归一化为概率:

    αti=exp(eti)k=1Texp(etk)=exp(vtanh(Whhi+Wsst+b))kexp(vtanh(Whhk+Wsst+b))\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{k=1}^{T} \exp(e_{tk})} = \frac{\exp(v^\top \tanh(W_h h_i + W_s s_t + b))}{\sum_k \exp(v^\top \tanh(W_h h_k + W_s s_t + b))}
    • $\alpha_{ti}$ 表示 Decoder 在时间步 $t$ 对 Encoder 隐藏状态 $h_i$ 的注意力权重

  • 将 Encoder 隐藏状态按注意力权重加权求和,得到上下文向量:

    ct=i=1Tαtihic_t = \sum_{i=1}^{T} \alpha_{ti} h_i
  • 上下文向量与 Decoder 隐藏状态结合,生成最终输出(例如预测下一个词):

    s~t=tanh(Wc[ct;st])yt=softmax(Wos~t)\tilde{s}_t = \tanh(W_c [c_t; s_t])\\ y_t = \text{softmax}(W_o \tilde{s}_t)
  • 与 Dot-Product Attention 对比

    特性
    Additive Attention
    Dot-Product (Luong, 2015)

    score function

    $v^\top \tanh(W_h h_i + W_s s_t)$

    $s_t^\top W h_i$ 或 $s_t^\top h_i$

    维度要求

    $d_h$ 与 $d_s$ 可以不同

    $d_h = d_s$(或需线性投影)

    计算复杂度

    高(需要 $\tanh$ 和向量乘法)

    低(只做矩阵乘法)

    精度表现

    对短序列效果略好

    对长序列训练效率高

  • 实际上,点乘注意力是 Bahdanau Attention 的高效近似

  • Bahdanau Attention 可以处理 Encoder 与 Decoder 不同维度的隐藏状态,其对齐模型可学习复杂非线性关系,并且动态生成上下文向量,提高长序列性能

  • 但是与 dot-product attention 相比,计算开销大。在 Transformer 出现后,由于可并行计算的优势,Additive Attention 多用于 RNN/GRU/LSTM 架构

Self-Attention

  • 自注意力(Self-Attention)用于 Transformer Encoder、Decoder

    Q=K=V=XQ = K = V = X
  • 输入序列的每个 token 与序列中的所有 token 进行注意力计算

  • 输出每个 token 的向量,是对整个序列上下文的加权融合

  • 每个词在生成自己的表示时 “看看” 其他词,按相关性分配权重

  • 自注意力是一种高层的概念,可能有不同的计算方法,如点乘缩放注意力、多头注意力

Cross-Attention

  • Cross-Attention (交叉注意力)是 Transformer 中 Encoder – Decoder 结构的核心机制之一,用于让一个序列的表示动态依赖另一个序列

  • 在 Encoder-Decoder Transformer 中:

    • Encoder 产生隐藏序列 $H^E = [h_1^E, \dots, h_T^E]$

    • Decoder 在生成每个输出时,通过 Cross-Attention 查询 Encoder 的表示

  • 在这里,Cross-Attention 的目标:

    • 根据 Decoder 当前隐藏状态(Query)从 Encoder 隐藏状态(Key/Value)中提取相关信息 —— Query 来自 Decoder,而 Key/Value 来自 Encoder

    • 动态生成上下文向量,增强输出的条件依赖性

  • 假设,Encoder 隐藏状态(Key / Value)中 $K = V = H^E \in \mathbb{R}^{T \times d_\text{model}}$,Decoder 当前隐藏状态(Query)中 $Q = H^D \in \mathbb{R}^{S \times d_\text{model}}$,其中 $T$ 表示 Encoder 序列长度,$S$ 表示 Decoder 序列长度,那么 Cross-Attention 本质是 Query 与外部序列的注意力机制,其输出为 $O \in \mathbb{R}^{S \times d_v}$

    Cross-Attention(Q,HE)=softmax(Q(HE)dk)HE\text{Cross-Attention}(Q,H^E) = \mathrm{softmax}\left(\frac{Q (H^E)^\top}{\sqrt{d_k}}\right) H^E
  • Cross-Attention 的直观理解

    • Query: Decoder 当前的 “问题” 或 “需求”

    • Key: Encoder 中各个位置的 “候选答案索引”

    • Value: Encoder 中各个位置的 “实际信息”

    • Softmax 权重:衡量 Decoder 当前需求与 Encoder 各位置信息的相关性

    • 输出是加权求和的上下文向量,指导 Decoder 生成下一步 token

    • 注意力权重 $\alpha_{ij}$ 显示 Decoder 每个输出 token 依赖 Encoder 哪些位置

特性
Self-Attention
Cross-Attention

Query、Key、Value 来源

同一序列

Query 来自 Decoder ,K / V 来自 Encoder

序列关系捕捉

内部依赖

条件依赖

Mask

Decoder 自注意力需要 causal mask

无 causal mask(可以访问整个 Encoder )

计算复杂度

$O(n^2)$

$O(S \cdot T)$

Multi-Head Attention

  • 基本思想

    • 在单头注意力中,除了输入投影外并不存在可学习的匹配参数,其相似度度量本质上是向量内积

    • 由于单一的内积形式难以捕获多样的模式,引入加性注意力以允许一个可学习的打分函数,从而获得更灵活的相似度估计

    • 多头注意力(Multi-Head Attention)通过提供 $h$ 组独立的投影,使模型能够在 $h$ 个不同的度量子空间中分别学习模式匹配策略

    • 各头在其独立空间中计算注意力后将输出拼接,并通过一次线性变换进行整合,从而提升表达能力

    • 对于第 $i$ 个头,模型使用可学习矩阵 $W_q^{(i)}, W_k^{(i)}, W_v^{(i)}$ 将 $Q, K, V$ 投影到维度 $d_v$ 的子空间,并在该子空间中计算注意力,得到 $\mathrm{head}_i$​

    multi-head attention

  • 数学公式

    • 将输入 $X$ 映射到多个不同的 $(Q_h, K_h, V_h)$ 子空间:

      Qi=XWiQ,Ki=XWiK,Vi=XWiVQ_i = X W_i^Q,\quad K_i = X W_i^K,\quad V_i = X W_i^V
      • 其中 $H$ 为注意力头数,各投影矩阵大小为:

        • $W_h^Q,W_h^K \in \mathbb{R}^{d_\text{model} \times d_k}$

        • $W_h^V \in \mathbb{R}^{d_\text{model} \times d_v}$

    • 每个头独立计算注意力:

      headh=Attention(Qh,Kh,Vh)\mathrm{head}_h=\mathrm{Attention}(Q_h,K_h,V_h)
    • 将所有头的输出拼接:

      Concat(head1,,headH)Rn×(Hdv)\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_H) \in \mathbb{R}^{n \times (H d_v)}
    • 并再通过一个线性层,重新投影到模型维度:

      MHA(X)=Concat(head1,,headH)WO\mathrm{MHA}(X) = \mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_H) W^O
    • 其中 $W^O \in \mathbb{R}^{Hd_v \times d_\text{model}}$,从而 $\mathrm{MHA}(X)\in \mathbb{R}^{n \times d_\text{model}}$

    • 线性投影的作用有两个

      • 信息混合: 它使得从不同头(代表不同子空间)中学习到的信息得以组合和整合,线性层充当一个学习到的组合函数

      • 维度匹配: 它确保输出张量具有 Transformer Block 其余部分所需的 $d_{model}$ 维度,从而能够进行残差连接和归一化等操作

  • 这种方法具有多种优势:

    • 多子空间学习:每个头学习不同的关系模式,如基于语法依赖的关系、基于语义的长距离依赖、基于局部结构的相似性

    • 信息并行提取:多个注意力分布可以同时学习,使模型对输入的不同方面保持敏感

    • 稳定训练:多头结构在高维空间分散参数,使每个头的注意力权重矩阵规模较小,数值更稳定

  • 本质上,多头机制首先将模型的表示能力分解为多个子空间($h$ 个头,每个维度为 $d_v$,允许每个头在其子空间内专门关注输入的不同方面,然后使用拼接,再进行线性投影($W_O$),将这些专业化的表示合并回一个维度为 $d_{model}$ 的单一、更丰富的表示

  • 计算复杂度

    • 对于输入序列长度 $n$、模型维度 $d_{\text{model}}$,其时间复杂度(总):

      O(Hn2dk)O(H \cdot n^2 \cdot d_k)
    • 由于 $H d_k = d_\text{model}$(标准 Transformer 设置),所以整体复杂度保持不变:

      O(n2dmodel)O(n^2 d_\text{model})
    • MHA 的多头并不会增加数量级,仅是常数倍的线性投影开销

  • MHA 算法的平面展示图如下:

mha平面图

Multi-Query Attention

  • 基本思想

    • Multi-Query Attention(MQA)是为了解决 Transformer 在推理(特别是自回归生成)阶段效率低、KV cache 过大等问题而提出的注意力变体

    • 它在保持模型质量基本不降的前提下显著降低推理开销,因此 GPT-3.5、PaLM、Gemini、LLaMA-3 等均采用了 MQA 或 GQA

    • 与 MHA 相比,MQA 的关键改动:每个注意力头依然有独立的 Query 投影,所有注意力头共享同一组 Key 和 Value

  • 数学公式(对于输入 $X\in\mathbb{R}^{n \times d_\text{model}}$)

    • Query(多头):每个头有自己独立的投影

      Qh=XWhQ,h=1,,HQ_h = X W_h^Q , \quad h=1,\dots,H
    • Key / Value(单头共享):使用单组投影

      K=XWK,V=XWVK = X W^K, \quad V = X W^V
      • 注意:这里不存在 $K_h$ 或 $V_h$

    • 注意力计算,对于第 $h$ 个头:

      headh=softmax(QhKdk)V\text{head}_h = \mathrm{softmax}\left(\frac{Q_h K^\top}{\sqrt{d_k}}\right)V
    • 区别仅在:

      • MHA:每个头有自己的 $K_h, V_h$

      • MQA:所有头共享同一个 $K, V$

    • 最后拼接得到:

      MQA(X)=Concat(head1,,headH),WO\mathrm{MQA}(X) = \mathrm{Concat}(\text{head}_1,\dots,\text{head}_H), W^O
  • 为什么 MQA 能加速推理?

    • KV Cache 大幅减少

      • 自回归推理中,需要缓存历史 token 的 $(K_t, V_t)$

      • MHA 的缓存大小(per layer)为 $O(n H d_k)$,而 MQA 则为 $O(n d_k)$

      • 缓存大小直接减少 H 倍(例如 H = 32 → KV cache 减少 32×),这对长上下文推理非常关键

    • 推理时每步注意力计算更快

      • 自回归中每次新 token 的注意力计算为:

        • MHA: $H$ 组 $(1 \times d_k) \cdot (n \times d_k)$

        • MQA: $H$ 组 Query,但共享 $(n \times d_k)$ 的 Key

      • K 的读取和 V 的读取次数减少了 $H$ 倍,这在 GPU/TPU 上极大降低了 memory bandwidth 压力(推理瓶颈)

      • 因此,MQA 的推理延迟显著降低,并提升吞吐(tokens/s)

    • 为什么共享 K / V 不会降低太多质量?

      • Q 决定 “看哪里”,K / V 决定 “有哪些内容可看”

      • 多头的差异主要在 Query 的投影(注意力查询方式不同)

      • 多头的 K / V 差异相对较小,实验显示合并对模型能力影响很小

      • PaLM 论文表明:在模型较大时(几十亿参数以上),共享 K / V 对质量影响极小,但能显著提升推理效率

  • MQA 算法的平面展示图如下:

mqa平面图

Grouped-Query Attention

  • 基本思想

    • Grouped-Query Attention(GQA)是 MQA 和 MHA 的折中形式:

    • 将多头分组,每组共享 K / V

    • 例如 $H=32$​,分为 4 组,则每组 8 个头共享 K / V

  • 三种形式

    • GQA-G 是指具有 G 组的 grouped-query attention

    • GQA-1 具有单个组,因此具有单个 Key 和 Value,等效于 MQA

    • 而 GQA-H 具有与头数相等的组,等效于 MHA

  • GQA 算法的平面展示图如下:

gqa平面图

MHA / MQA / GQA

  • MHA / MQA / GQA 基本概念对比($H$ 表示注意力头数,$G$​ 表示 GQA 的分组数)

    特性
    MHA(Multi-Head Attention)
    MQA(Multi-Query Attention)
    GQA(Grouped-Query Attention)

    Query (Q)

    每头独立

    每头独立

    每头独立

    Key (K)

    每头独立

    所有头共享 1 份

    每组头共享 1 份

    Value (V)

    每头独立

    所有头共享 1 份

    每组头共享 1 份

    KV cache 内存

    最大,$O (n × H × d_k)$

    最小,$O (n × d_k)$

    中等,$O (n × G × d_k)$

    推理速度

    较慢

    快(尤其长上下文)

    较快,介于 MHA 和 MQA 之间

    表达能力

    最强

    略低于 MHA

    介于 MHA 与 MQA 之间

    使用场景

    训练/小模型/高精度

    自回归大模型推理

    训练阶段或推理折中方案

  • MHA、MQA、GQA 算法的对比如下:

mha-mqa-gqa对比

  • MHA / MQA / GQA 参数量比较(设 $d_\text{model}$ 表示隐藏维度,$H$ 表示注意力头数,$d_k = d_\text{model}/H$,$W^O \in \mathbb{R}^{H d_v \times d_\text{model}}$)

    • MHA(每头独立 Q/K/V)

      Params(MHA)=H3(dmodeldk)+dmodel(Hdv)=4dmodel2\text{Params(MHA)} = H \cdot 3 (d_\text{model} \cdot d_k) + d_\text{model} \cdot (H d_v) = 4 d_\text{model}^2
    • MQA(所有头共享 K/V,对大 $H$(如 32),参数略小于 MHA)

      Params(MQA)=H(dmodeldk)+2(dmodeldk)+dmodel(Hdv)(3+2/H)dmodel2\text{Params(MQA)} = H (d_\text{model} \cdot d_k) + 2 (d_\text{model} \cdot d_k) + d_\text{model} \cdot (H d_v) \approx (3 + 2/H) d_\text{model}^2
    • GQA(每组头共享 K/V,组数 $G$)

      Params(GQA)=H(dmodeldk)+G2(dmodeldk)+dmodel(Hdv)=(3+2GH)dmodel2\text{Params(GQA)} = H (d_\text{model} \cdot d_k) + G \cdot 2 (d_\text{model} \cdot d_k) + d_\text{model} \cdot (H d_v) = \left(3 + \frac{2G}{H}\right)d_\text{model}^2

Sparse Attention

什么是稀疏注意力

  • 标准注意力存在的缺点

    • 对长序列,计算量 $O(n^2 \cdot d_k)$,显存需求高(其中 $n$ 是序列长度)

    • 特别是在自回归推理或大 batch 训练时,矩阵 $QK^\top$ 和 softmax 中间缓冲区会占用大量显存

    • 为了降低原本 $O(n^2)$ 的注意力计算成本,设计了一系列方法,如 Sparse Attention、Performer、Linear Attention、FlashAttention

  • Sparse Attention 的核心思想

    • 稀疏注意力(Sparse Attention)只计算部分相关位置的注意力,而不是全量 $n \times n$ 计算,降低复杂度并保持性能

    • 但显然,计算部分位置的注意力的效果理论上不如计算全部位置的注意力的效果

  • Sparse Attention 的数学形式

    • 令 $S_i$ 为 Query $i$ 对应的 Key 索引集合(稀疏模式):

      αi[j]={exp(QiKj/dk)kSiexp(QiKk/dk),jSi 0,jSi\alpha_i[j] = \begin{cases} \frac{\exp(Q_i K_j^\top / \sqrt{d_k})}{\sum_{k \in S_i} \exp(Q_i K_k^\top / \sqrt{d_k})}, & j \in S_i \ 0, & j \notin S_i \end{cases}
    • 上下文向量:

      ci=jSiαi[j]Vjc_i = \sum_{j \in S_i} \alpha_i[j] V_j
    • 复杂度由 $O(n^2)$ 降为 $O(n \cdot |S_i|)$​,显著节省计算和显存

  • 为了实现稀疏注意力,可以引入稀疏模式矩阵 $M \in {0,1}^{n \times n}$,定义哪些 Query-Key 对需要计算注意力,也即 Masked Sparse Attention

    SparseAttention(Q,K,V)=softmax(QKdkM)V\text{SparseAttention}(Q,K,V) = \mathrm{softmax}\Big(\frac{Q K^\top}{\sqrt{d_k}} \odot M\Big) V
    • $\odot$ 表示元素级掩码

    • $M_{ij}=1$ 表示 $Q_i$ 与 $K_j$ 有效计算

    • $M_{ij}=0$ 表示跳过计算(节省 FLOPs 和内存)

    • 这种方式显著降低了计算量,复杂度 $O(n \cdot k)$,其中 $k$ 是每个 Query 有效 Key 的数量($k \ll n$)

  • Sparse Attention 的优缺点

    • 通过引入稀疏注意力,提高了 Transformer 的长序列可扩展性,并降低计算量与显存消耗

    • 同时,可以组合多种掩码模式(局部/随机/跨步长),保证模型捕捉局部和长程依赖

      • 但是也存在一些缺点:

        • 稀疏模式设计难:不同任务可能需要不同模式(Local / Strided / Global)

        • 梯度传播受限:某些未被注意的 Key 无法直接影响输出,需要全局 token 或跳跃连接

        • 硬件优化复杂:GPU / TPU 更适合连续矩阵乘法,稀疏模式可能导致内存访问不连续

  • Sparse Attention 的常见模式

    • Local / Sliding Window Attention(局部窗口注意力)

      • 每个 Query 仅关注固定窗口 $w$ 内的 Key

      • 适合长文本、图像 patch 或时间序列

      • 复杂度 $O(n \cdot w)$,通常 $w \ll n$

    • Strided / Dilated Attention

      • 每个 Query 按固定步长访问 Key(如每隔 $s$ 个位置)

      • 可扩大感受野而保持稀疏

      • 结合局部窗口可捕捉长程依赖

    • Global + Local Attention

      • 少数全局 token(CLS 或 summary token)关注全序列

      • 大部分 token 仅做局部注意力

      • 代表模型:Longformer、BigBird

    • Random / Learned Sparse Patterns

      • 每个 Query 随机选择 $k$ 个 Key 或通过训练学习稀疏模式

      • BigBird 使用了 Random + Local + Global 组合,保证稀疏性的同时保持理论上各个 token 之间可完全互连

  • Sparse Attention 代表模型

    模型
    Sparse Attention 模式
    复杂度

    Sliding window + global token

    $O(n)$

    Local + Global + Random

    $O(n)$

    Local + LSH (hashing)

    $O(n\log n)$

    Kernel-based attention(稀疏近似),高效近似 softmax attention

    $O(n d^2)$

    低秩矩阵近似

Performer

  • Performer(Choromanski et al., 2020)提出线性注意力,通过核方法近似 softmax,实现 $O(n d^2)$ 复杂度,显著提升长序列可扩展性

  • 核心思想

    • 核心问题:标准注意力中 $\text{softmax}(Q K^T)$ 是 $n \times n$ 矩阵,计算量 $O(n^2)$

    • 将 softmax attention 视为核函数:

      softmax(QiKj)=ϕ(Qi)ϕ(Kj)\text{softmax}(Q_i K_j^\top) = \phi(Q_i)^\top \phi(K_j)
      • 其中 $\phi(\cdot)$ 是正交随机特征映射

    • 利用线性化性质:

      Attention(Q,K,V)i=jsoftmax(QiKj)Vjjϕ(Qi)ϕ(Kj)Vj\text{Attention}(Q,K,V)_i = \sum_j \text{softmax}(Q_i K_j^\top) V_j \approx \sum_j \phi(Q_i)^\top \phi(K_j) V_j
    • 重新排列求和顺序:

      Attention(Q,K,V)iϕ(Qi)(jϕ(Kj)Vj)\text{Attention}(Q,K,V)_i \approx \phi(Q_i)^\top \left(\sum_j \phi(K_j) V_j \right)
    • 关键优化:先对 K / V 做线性组合,再与 Q 做点积,实现线性复杂度

  • 数学公式

    • 定义映射 $\phi: \mathbb{R}^{d} \rightarrow \mathbb{R}^{r}$,例如:

      ϕ(x)=exp(ωxx2/2),ωN(0,I)\phi(x) = \exp(\omega^\top x - |x|^2 / 2), \quad \omega \sim \mathcal{N}(0, I)
    • Attention 线性化:

      LinearAttention(Q,K,V)=ϕ(Q)(ϕ(K)V)\text{LinearAttention}(Q,K,V) = \phi(Q) (\phi(K)^\top V)
      • $\phi(Q) \in \mathbb{R}^{n \times r}$

      • $\phi(K)^\top V \in \mathbb{R}^{r \times d_v}$

      • 输出 $O = \phi(Q) (\phi(K)^\top V) \in \mathbb{R}^{n \times d_v}$

    • 正则化处理(确保概率性质):

      Oi=ϕ(Qi)jϕ(Kj)Vjϕ(Qi)jϕ(Kj)O_i = \frac{\phi(Q_i)^\top \sum_j \phi(K_j) V_j}{\phi(Q_i)^\top \sum_j \phi(K_j)}

Linformer

RoPE 中的旋转注意力

  • 本质是把位置信息编码进 $Q/K$:

    Q~=QRθ,K~=KRθ\tilde{Q} = Q R_{\theta},\quad \tilde{K} = K R_{\theta}
  • 拥有相对位置性质

  • 支持更强的外推能力(extreme long context)

ALiBi 中的线性偏置注意力

  • 通过给注意力打分加入位置偏置:

    Sij=QiKjT+mh(ij)S_{ij} = Q_i K_j^T + m_h \cdot (i-j)
  • 适用于超长上下文任务,训练短上下文 → 推理超长上下文

Last updated

Was this helpful?