3. FlashAttention
介绍 FlashAttention 技术在注意力机制中的优化
参考资料
Block-sparse FlashAttention:Block Sparse Flash Attention
FlashAttention
基本思想
FlashAttention 提出了一种显存高效、GPU / TPU friendly 的注意力实现,在保持数值精度的前提下显著降低显存使用,并加速训练和推理
主要贡献包括:
计算 softmax 时,不需要全量 input 数据,可以分段计算
反向传播的时候,不存储 $N^2$ 注意力矩阵,而是只存储 softmax 归一化的系数
其优化思路来自于 GPU 的层级结构:
GPU 中不同硬件层级在带宽与存储容量上存在显著差异
共享内存 SRAM
虽然容量极小,但具备极高的带宽
以 A100 GPU 为例,其芯片包含 108 个流式多处理器(SM),每个 SM 配备约 192 KB 的片上 SRAM
因此总体 SRAM 容量约为 $192\ \text{KB} \times 108 \approx 20\ \text{MB}$
尽管容量有限,SRAM 的带宽却可高达约 19 TB/s,远超 GPU 其他层级的存储
普通显存
A100 的 HBM(High Bandwidth Memory,即通常所说的 GPU 显存)容量在 40–80 GB 范围内,显著大于SRAM,但其带宽仅约 1.5 TB/s
这一对比体现出典型的“容量越大、带宽越低、访问延迟越高”的存储层级规律
FlashAttention 的核心动机在于最大化利用 GPU 片上 SRAM 的超高带宽,然而,SRAM 容量极其有限,通常无法容纳完整的注意力矩阵或标准矩阵乘法所需的全部中间张量
针对这一瓶颈,FlashAttention 的基本策略是对注意力计算过程进行精细化分块,将原本的大规模矩阵运算拆解为可在 SRAM 中容纳的小型计算单元
这些子任务按顺序流式处理,使计算始终在片上完成,从而显著减少对带宽较低的 HBM 的访问次数,实现 Attention 的 IO-optimal 计算
为了实现减少对高带宽内存的访问,具体实现包括两个方面:
增量计算 Softmax 缩减
为避免一次性访问整个输入序列以计算 Softmax,FlashAttention 将输入划分为若干块(block),并在每个输入块上多次迭代计算,从而以增量方式完成 Softmax 缩减
设输入矩阵为 $X \in \mathbb{R}^{L \times d}$,注意力权重为 $A = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$,则 FlashAttention 对每个块 $X_i$ 分别计算局部最大值和归一化因子,再在全局范围内进行累积,从而高效获得最终 Softmax 输出,而无需同时存储完整的 $QK^\top$
后向传播中减少中间矩阵存储
在标准 Attention 实现中,反向传播通常需要保存中间矩阵 $S$ 和 $P$(例如 $\text{Softmax}(QK^\top)$,这也就是 Softmax 中间缓存),其尺寸与序列长度 $L$ 呈二次关系,导致 HBM 占用量巨大
FlashAttention 则通过仅保存归一化因子来避免存储完整的中间注意力矩阵,从而显著降低内存消耗
Softmax 分块化算法
1. 数值稳定的 Softmax
Softmax 运算涉及指数函数,为避免数值溢出,可对输入向量每个元素减去最大值(max-shift)
定义如下:
m(x)f(x)ℓ(x):=imaxxi:=[ex1−m(x),…,exB−m(x)]:=i∑f(x)i则 Softmax 表示为:
softmax(x)=ℓ(x)f(x)
2. 分块计算 Softmax
考虑将一行数据 $x \in \mathbb{R}^{2B}$ 切分为两部分 $x = [x^{(1)}, x^{(2)}]$,其中 $x^{(1)}, x^{(2)} \in \mathbb{R}^B$。则 Softmax 可按块计算如下:
最大值分解:
m(x)=max(m(x(1)), m(x(2)))指数向量分解:
f(x)=[em(x(1))−m(x)f(x(1)),,,em(x(2))−m(x)f(x(2))]归一化项(分母)分解:
ℓ(x)=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2))最终 Softmax 表达:
softmax(x)=ℓ(x)f(x)可见,在计算各块 $f(x)$ 时,需要乘上不同的系数 $e^{m(x^{(k)}) - m(x)}$,但整体结果等价于对整行向量使用数值稳定的 Softmax 运算
例如,对于 $x^{(1)}$:
em(x(1))−m(x)f(x(1))=em(x(1))−m(x)[ex1(1)−m(x(1)),…,exB(1)−m(x(1))]=[ex1(1)−m(x),…,exB(1)−m(x)]因此,分块计算 Softmax 保持了与整行计算一致的数值稳定性
算法公式
假设序列长度为 $n$,查询 $Q$、键 $K$、值 $V$ 被划分为 $B$ 个 block,每个 block 的长度为 $b = \frac{n}{B}$
1. 序列划分与分块处理
将 $Q, K, V \in \mathbb{R}^{n \times d}$ 按行划分为 $B$ 个块:
Q=[Q1 Q2 ⋮ QB]K=[K1 K2 ⋮ KB]V=[V1 V2 ⋮ VB]每个 block 大小为 $Q_i, K_i, V_i \in \mathbb{R}^{b \times d}$,由此通过块级计算来避免一次性构造完整的 $QK^\top$ 矩阵($n \times n$)
2. 前向传播
Step 1: 局部相似度计算——对于每个查询块 $Q_i$ 和键块 $K_j$,计算局部点积:
Sij=QiKj⊤∈Rb×bStep 2: 局部最大值与归一化因子
为了稳定 Softmax 并实现增量计算,引入局部最大值 $m_{ij}$ 和局部归一化因子 $l_{ij}$:
mij=max(Sij),lij=∑exp(Sij−mij)这些局部统计量将用于全局 Softmax 的累积更新,而无需存储完整 $S$ 矩阵
Step 3: 增量 Softmax 累积
对于查询块 $Q_i$,按顺序处理每个键块 $K_j$,维护:
全局最大值 $m_i = \max_j m_{ij}$
累积归一化因子 $L_i = \sum_j l_{ij} \cdot \exp(m_{ij} - m_i)$
然后增量更新输出:
Oi←Oi+exp(Sij−mi)Vj/Li这样可以按块计算 Softmax 输出,避免一次性构建 $n \times n$ 注意力矩阵
3. 后向传播
在标准 Attention 中,反向传播需要保存完整 Softmax 矩阵 $\text{Softmax}(QK^\top)$
FlashAttention 通过只存储归一化因子 $L_i$ 和局部最大值 $m_i$,在反向传播中逐块重新计算梯度:
对每个 block $Q_i, K_j, V_j$,重新计算局部 Softmax
使用局部 Softmax 与存储的归一化因子组合,得到 $\frac{dL}{dQ}, \frac{dL}{dK}, \frac{dL}{dV}$
由于增量 Softmax 保证数值稳定性,梯度可以在块级别计算,而无需构造完整中间矩阵
FlashAttention 的计算复杂度依旧是 $O(n^2 d)$,但显存占用从标准注意力的 $O(n^2 + n d)$ 降为 $O(n d)$
FlashAttention 并没有减少模型实际进行的浮点运算次数(Floating Point Operations, FLOPs),但显存瓶颈大幅降低,因此实际训练速度提升
FlashAttention2
FlashAttention3
FlexAttention
Block-Sparse FlashAttention
Last updated
Was this helpful?