3. FlashAttention

介绍 FlashAttention 技术在注意力机制中的优化

FlashAttention

基本思想

  • FlashAttention 提出了一种显存高效、GPU / TPU friendly 的注意力实现,在保持数值精度的前提下显著降低显存使用,并加速训练和推理

  • 主要贡献包括:

    • 计算 softmax 时,不需要全量 input 数据,可以分段计算

    • 反向传播的时候,不存储 $N^2$ 注意力矩阵,而是只存储 softmax 归一化的系数

  • 其优化思路来自于 GPU 的层级结构:

    • GPU 中不同硬件层级在带宽与存储容量上存在显著差异

    • 共享内存 SRAM

      • 虽然容量极小,但具备极高的带宽

      • 以 A100 GPU 为例,其芯片包含 108 个流式多处理器(SM),每个 SM 配备约 192 KB 的片上 SRAM

      • 因此总体 SRAM 容量约为 $192\ \text{KB} \times 108 \approx 20\ \text{MB}$

      • 尽管容量有限,SRAM 的带宽却可高达约 19 TB/s,远超 GPU 其他层级的存储

    • 普通显存

      • A100 的 HBM(High Bandwidth Memory,即通常所说的 GPU 显存)容量在 40–80 GB 范围内,显著大于SRAM,但其带宽仅约 1.5 TB/s

      • 这一对比体现出典型的“容量越大、带宽越低、访问延迟越高”的存储层级规律

    • FlashAttention 的核心动机在于最大化利用 GPU 片上 SRAM 的超高带宽,然而,SRAM 容量极其有限,通常无法容纳完整的注意力矩阵或标准矩阵乘法所需的全部中间张量

    • 针对这一瓶颈,FlashAttention 的基本策略是对注意力计算过程进行精细化分块,将原本的大规模矩阵运算拆解为可在 SRAM 中容纳的小型计算单元

    • 这些子任务按顺序流式处理,使计算始终在片上完成,从而显著减少对带宽较低的 HBM 的访问次数,实现 Attention 的 IO-optimal 计算

  • 为了实现减少对高带宽内存的访问,具体实现包括两个方面:

    • 增量计算 Softmax 缩减

      • 为避免一次性访问整个输入序列以计算 Softmax,FlashAttention 将输入划分为若干块(block),并在每个输入块上多次迭代计算,从而以增量方式完成 Softmax 缩减

      • 设输入矩阵为 $X \in \mathbb{R}^{L \times d}$,注意力权重为 $A = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$,则 FlashAttention 对每个块 $X_i$ 分别计算局部最大值和归一化因子,再在全局范围内进行累积,从而高效获得最终 Softmax 输出,而无需同时存储完整的 $QK^\top$

    • 后向传播中减少中间矩阵存储

      • 在标准 Attention 实现中,反向传播通常需要保存中间矩阵 $S$ 和 $P$(例如 $\text{Softmax}(QK^\top)$,这也就是 Softmax 中间缓存),其尺寸与序列长度 $L$ 呈二次关系,导致 HBM 占用量巨大

      • FlashAttention 则通过仅保存归一化因子来避免存储完整的中间注意力矩阵,从而显著降低内存消耗

Softmax 分块化算法

  • 1. 数值稳定的 Softmax

    • Softmax 运算涉及指数函数,为避免数值溢出,可对输入向量每个元素减去最大值(max-shift)

    • 定义如下:

      m(x):=maxixif(x):=[ex1m(x),,exBm(x)](x):=if(x)i\begin{align} m(x) &:= \max_i x_i \\ f(x) &:= \left[e^{x_1 - m(x)}, \ldots, e^{x_B - m(x)}\right] \\ \ell(x) &:= \sum_i f(x)_i \end{align}
    • 则 Softmax 表示为:

      softmax(x)=f(x)(x)\operatorname{softmax}(x) = \frac{f(x)}{\ell(x)}
  • 2. 分块计算 Softmax

    • 考虑将一行数据 $x \in \mathbb{R}^{2B}$ 切分为两部分 $x = [x^{(1)}, x^{(2)}]$,其中 $x^{(1)}, x^{(2)} \in \mathbb{R}^B$。则 Softmax 可按块计算如下:

    • 最大值分解:

      m(x)=max(m(x(1)), m(x(2)))m(x) = \max\big(m(x^{(1)}),\ m(x^{(2)})\big)
    • 指数向量分解:

      f(x)=[em(x(1))m(x)f(x(1)),,,em(x(2))m(x)f(x(2))]f(x) = \left[ e^{m(x^{(1)}) - m(x)} f(x^{(1)}) ,,, e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \right]
    • 归一化项(分母)分解:

      (x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))\ell(x) = e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)})
    • 最终 Softmax 表达:

      softmax(x)=f(x)(x)\operatorname{softmax}(x) = \frac{f(x)}{\ell(x)}
    • 可见,在计算各块 $f(x)$ 时,需要乘上不同的系数 $e^{m(x^{(k)}) - m(x)}$,但整体结果等价于对整行向量使用数值稳定的 Softmax 运算

    • 例如,对于 $x^{(1)}$:

      em(x(1))m(x)f(x(1))=em(x(1))m(x)[ex1(1)m(x(1)),,exB(1)m(x(1))]=[ex1(1)m(x),,exB(1)m(x)]\begin{align} e^{m(x^{(1)}) - m(x)} f(x^{(1)}) &= e^{m(x^{(1)}) - m(x)} \left[e^{x_1^{(1)} - m(x^{(1)})}, \ldots, e^{x_B^{(1)} - m(x^{(1)})}\right] \\ &= \left[e^{x_1^{(1)} - m(x)}, \ldots, e^{x_B^{(1)} - m(x)}\right] \end{align}
    • 因此,分块计算 Softmax 保持了与整行计算一致的数值稳定性

算法公式

  • 假设序列长度为 $n$,查询 $Q$、键 $K$、值 $V$ 被划分为 $B$ 个 block,每个 block 的长度为 $b = \frac{n}{B}$

  • 1. 序列划分与分块处理

    • 将 $Q, K, V \in \mathbb{R}^{n \times d}$ 按行划分为 $B$ 个块:

      Q=[Q1 Q2  QB]K=[K1 K2  KB]V=[V1 V2  VB]Q = \begin{bmatrix} Q_1 \ Q_2 \ \vdots \ Q_B \end{bmatrix}\\ K = \begin{bmatrix} K_1 \ K_2 \ \vdots \ K_B \end{bmatrix}\\ V = \begin{bmatrix} V_1 \ V_2 \ \vdots \ V_B \end{bmatrix}
    • 每个 block 大小为 $Q_i, K_i, V_i \in \mathbb{R}^{b \times d}$,由此通过块级计算来避免一次性构造完整的 $QK^\top$ 矩阵($n \times n$)

  • 2. 前向传播

    • Step 1: 局部相似度计算——对于每个查询块 $Q_i$ 和键块 $K_j$,计算局部点积:

      Sij=QiKjRb×bS_{ij} = Q_i K_j^\top \in \mathbb{R}^{b \times b}
    • Step 2: 局部最大值与归一化因子

      • 为了稳定 Softmax 并实现增量计算,引入局部最大值 $m_{ij}$ 和局部归一化因子 $l_{ij}$:

        mij=max(Sij),lij=exp(Sijmij)m_{ij} = \max(S_{ij}), \quad l_{ij} = \sum \exp(S_{ij} - m_{ij})
      • 这些局部统计量将用于全局 Softmax 的累积更新,而无需存储完整 $S$ 矩阵

    • Step 3: 增量 Softmax 累积

      • 对于查询块 $Q_i$,按顺序处理每个键块 $K_j$,维护:

        • 全局最大值 $m_i = \max_j m_{ij}$

        • 累积归一化因子 $L_i = \sum_j l_{ij} \cdot \exp(m_{ij} - m_i)$

      • 然后增量更新输出:

        OiOi+exp(Sijmi)Vj/LiO_i \gets O_i + \exp(S_{ij} - m_i) V_j / L_i
      • 这样可以按块计算 Softmax 输出,避免一次性构建 $n \times n$ 注意力矩阵

  • 3. 后向传播

    • 在标准 Attention 中,反向传播需要保存完整 Softmax 矩阵 $\text{Softmax}(QK^\top)$

    • FlashAttention 通过只存储归一化因子 $L_i$ 和局部最大值 $m_i$,在反向传播中逐块重新计算梯度:

      • 对每个 block $Q_i, K_j, V_j$,重新计算局部 Softmax

      • 使用局部 Softmax 与存储的归一化因子组合,得到 $\frac{dL}{dQ}, \frac{dL}{dK}, \frac{dL}{dV}$

    • 由于增量 Softmax 保证数值稳定性,梯度可以在块级别计算,而无需构造完整中间矩阵

  • FlashAttention 的计算复杂度依旧是 $O(n^2 d)$,但显存占用从标准注意力的 $O(n^2 + n d)$ 降为 $O(n d)$

  • FlashAttention 并没有减少模型实际进行的浮点运算次数(Floating Point Operations, FLOPs),但显存瓶颈大幅降低,因此实际训练速度提升

FlashAttention2

FlashAttention3

FlexAttention

Block-Sparse FlashAttention

Last updated

Was this helpful?