1. 分布式训练

介绍分布式训练的基本概念与技术

什么是分布式训练

分布式训练的基本概念

  • 模型训练一般包括三个步骤

    • 通过模型传递输入以产生输出的前向传播

    • 反向传播计算梯度

    • 利用梯度更新参数

  • 分布式训练关心的三个核心问题

    • 内存占用(Memory Usage):这是一个硬性限制——如果一次训练步骤无法装入内存,训练就无法继续进行

    • 计算效率(Compute Efficiency):我们希望硬件的大部分时间都用于实际计算,因此需要减少花在数据传输上的时间,或等待其他 GPU 完成工作的空闲时间

    • 通信开销(Communication Overhead):我们希望将通信带来的额外开销降到最低,因为通信会使 GPU 处于空闲状态。为此,需要尽可能充分利用节点内(速度更快)和节点间(速度较慢)的带宽,并尽量将通信与计算进行重叠——此也即关键改进点,Overlap computation and communication(计算与通信重叠)

  • 主节点(Master Node)

    • 主节点通常负责管理任务调度、参数初始化和分布式通信的协调

    • 在 PyTorch 分布式环境中,主节点通常具有以下职责:

      • 初始化分布式环境(init_process_group

      • 分发模型参数或数据到其他节点

      • 处理日志记录或统计信息

    • 在 PyTorch 中,主节点由环境变量 MASTER_ADDRMASTER_PORT 指定

    • 主节点并不一定承担实际训练任务,可能只负责通信和协调

  • 进程组(Process Group)

    • 进程组是 PyTorch 分布式通信的基本单元,表示一组参与通信的分布式进程,定义哪些进程参与分布式操作

    • 提供通信接口,如广播(Broadcast)、全归约(AllReduce)等

    • 初始化:

      • 使用 torch.distributed.init_process_group() 初始化

      • 支持多种后端,如 NCCL(GPU)、Gloo(CPU/GPU)和 MPI

    • 如果有 4 个 GPU 参与训练,可以初始化一个包含 4 个进程的进程组,每个 GPU 对应一个进程

  • Rank

    • 在分布式训练中,Rank 是用于标识每个进程的唯一编号,用于区分不同的计算任务和通信操作

    • Global Rank:在整个分布式系统中唯一标识一个进程

    • Local Rank:在某个节点(物理机器)中标识该节点内的进程编号

    • Rank 决定进程的角色,例如哪一个进程是主节点

    • 在多 GPU 训练中,Rank 通常和 GPU ID 对应

    • 如果有 8 个 GPU 和 2 个节点,每个节点有 4 个 GPU:

      • 节点 1 的全局 Rank 为 0, 1, 2, 3

      • 节点 2 的全局 Rank 为 4, 5, 6, 7

  • NCCL(NVIDIA Collective Communication Library)

    • NVIDIA 提供的高性能通信库,专为 GPU 集群上的深度学习训练优化

    • 提供高效的分布式通信操作,包括广播、归约(Reduce)、全归约(AllReduce)、聚合(AllGather)等

    • 利用 GPU 的高速 NVLink、PCIe 和 InfiniBand 实现低延迟、高带宽通信

      • 支持多 GPU 和多节点通信

      • 深度集成到 PyTorch 和 TensorFlow 等框架中

      • 自动优化通信拓扑,减少通信瓶颈

  • Backend:分布式通信的底层实现方式,在 PyTorch 中常见的后端包括:

    • NCCL:用于 GPU 通信,性能最佳

    • Gloo:支持 CPU 和 GPU,适用于多种环境

    • MPI:使用消息传递接口,适合高性能计算环境

分布式训练库

torchrun

  • 功能

    • torchrun 是 PyTorch 提供的一个 CLI 工具,用于管理分布式训练任务,特别是基于 torch.distributed 的分布式环境

    • 负责初始化 PyTorch 的分布式训练环境

    • 自动设置分布式所需的主节点地址 (MASTER_ADDR) 和端口 (MASTER_PORT)

    • 简化进程启动,支持多 GPU 和多节点训练

    • 替代了旧的 torch.distributed.launch 工具,提供更易用和灵活的接口

  • 工作原理

    • torchrun 为每个 GPU 启动一个进程

    • 使用 NCCL(默认后端)或其他后端初始化通信

    • 进程通过 torch.distributed.init_process_group 互相通信,完成梯度同步

  • 单卡多 GPU

  • 多卡多 GPU

Accelerate

  • 功能

    • 自动处理多 GPU、TPU 和混合精度训练(FP16)

    • 自动设备分配:在 CPU、单 GPU 和多 GPU 环境中无缝切换

    • 支持分布式训练的关键操作,如梯度累积、参数同步等

  • 示例代码

单个 GPU 的训练

Batch Size 的影响

  • Batch Size 和训练时间的关系

    • 每个 epoch 的优化器步数 = 总样本数/批次大小

    • 而每个优化器步数需要进行的操作

      • 前向传播(forward)

      • 反向传播(backward)

      • 梯度同步 / 累积(多卡时)

      • 参数更新(optimizer step)

    • backward 通常是 forward 的 2–3 倍计算量,optimizer step 涉及大量参数读写,在分布式训练中还要做梯度通信(all-reduce),所以一次 step 的固定开销很大

    • 训练的总时间 ≈ steps×( 固定成本+ 每个批次的计算时间),而当 batch 很小时固定成本占比变大,GPU 更容易空转

    • 固定成本包括

      • Kernel launch 开销(GPU):每个算子都要 launch kernel,kernel launch 本身有不可忽略的延迟,batch 小时,算子算得快,但 launch 时间不变

      • 反向传播的调度与同步:autograd 图的遍历,backward 中的算子依赖管理

      • 梯度清零(zero_grad):需要遍历全部参数,batch 再小,也要清一整套梯度

      • 优化器更新参数 Adam / AdamW:读参数、读一阶、二阶动量、写回更新结果,而参数量固定 → 则这部分成本固定

    • 每个批次的计算时间会随 batch size 增长而增长,包括矩阵乘法 FLOPs、attention 计算、token-level loss、embedding lookup 的有效计算部分

  • Batch Size Tokens

    • 在大语言模型预训练领域,batch size 通常以 token 数量来表示,而不是以样本数来表示(即 Batch Size Tokens)

    • 这种表示方式使得训练规模在数值上基本不依赖于训练过程中所使用的具体输入序列长度

    • 通常而言 Batch Size Tokens = Batch Size Samples × Sample Token Length

    • Llama 1 使用约 4M tokens 的 batch size 训练了 1.4T tokens,而 DeepSeek 使用约 60M tokens 的 batch size 训练了 14T tokens

  • 在将模型训练扩展到大规模 batch size时,面临的首要挑战就是内存不足问题。当 GPU 内存不足以容纳 batch 大小时,应该如何应对

为什么需要存储激活值

  • 从神经网络的函数复合本质与反向传播中的链式法则出发,可以直接得到核心结论

  • 反向传播在计算梯度时,所需的局部导数依赖前向传播产生的激活值,因此激活值必须在前向阶段被存储,以供反向阶段使用

  • 神经网络的本质

    • 神经网络本质上是复合函数,对任意一层前馈神经网络(含 Transformer block),其前向计算为:

      z(l)=W(l)a(l1)+b(l)a(l)=ϕ(l)(z(l))z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}\\ a^{(l)} = \phi^{(l)}(z^{(l)})
    • 整个模型的损失函数是高度嵌套的函数复合:

      L=L(ϕ(L)(W(L)ϕ(L1)(ϕ(1)(W(1)x))))\mathcal{L}= \mathcal{L}\Big( \phi^{(L)}(W^{(L)} \phi^{(L-1)}(\cdots \phi^{(1)}(W^{(1)} x))) \Big)
    • 反向传播要做的,就是对这个复合函数逐层应用链式法则

  • 权重梯度显式依赖前向激活值

    • 考虑某一层参数的梯度:

      LW(l)\frac{\partial \mathcal{L}}{\partial W^{(l)}}
    • 由链式法则:

      LW(l)=Lz(l)z(l)W(l)\frac{\partial \mathcal{L}}{\partial W^{(l)}}= \frac{\partial \mathcal{L}}{\partial z^{(l)}} \cdot \frac{\partial z^{(l)}}{\partial W^{(l)}}
    • 而线性层有:

      z(l)W(l)=a(l1)\frac{\partial z^{(l)}}{\partial W^{(l)}} = a^{(l-1)}
    • 因此:

      LW(l)=Lz(l)a(l1)\frac{\partial \mathcal{L}}{\partial W^{(l)}}= \frac{\partial \mathcal{L}}{\partial z^{(l)}} a^{(l-1)}
    • 可见,前一层的激活值 $a^{(l-1)}$ 是权重梯度计算中不可缺失的因子,如果前向传播时没有保存 $a^{(l-1)}$,反向传播在数学上就无法进行

  • 非线性层的梯度同样依赖激活值

    • 激活层:

      a(l)=ϕ(z(l))a^{(l)} = \phi(z^{(l)})
    • 反向传播需要:

      Lz(l)=La(l)ϕ(z(l))\frac{\partial \mathcal{L}}{\partial z^{(l)}}= \frac{\partial \mathcal{L}}{\partial a^{(l)}} \cdot \phi'(z^{(l)})
    • 而 $\phi'(z^{(l)})$ 在实际中必须通过前向结果确定:

      • ReLU:$\phi'(z) = \mathbf{1}[z > 0]$,需要知道 forward 时哪些 $z$ 为正

      • Sigmoid:$\phi'(z) = \sigma(z)(1 - \sigma(z))$,通常直接由已存的 $a^{(l)}$ 计算

      • GELU / SiLU:梯度是激活值或其数值近似的函数

    • 非线性层的梯度不是常数,必须依赖 forward 的激活信息

  • 从计算图角度看:激活值是“边上的数值”

    • 自动微分将模型表示为计算图:节点——张量,边——算子

    • 反向传播在图上计算:

      yx\frac{\partial y}{\partial x}
    • 这些局部 Jacobian几乎都依赖前向阶段的具体数值

    • 因此反向传播不是符号推导,而是在已知前向数值的前提下,执行数值化的链式法则

    • 这就要求前向阶段的中间结果(激活值)必须被保存

  • 为什么不能等反向时再算激活值?

    • 理论上可以重算,但代价是:每一层反向都需要重新执行前向,而前向本身又依赖更早的激活

    • 总复杂度变为:

      O(layers×forward)O(\text{layers} \times \text{forward})
    • 而不是标准的:

      O(forward+backward)O(\text{forward} + \text{backward})
    • 这在深层网络(尤其是 Transformer)中是不可接受的

Transformer 中需要保存的激活值

  • Transformer layer 的计算结构

    • 设输入为 $X \in \mathbb{R}^{B \times S \times d}$(B 表示 Batch Size 即样本数,S 表示 Sequence Length 即Token数,d 表示嵌入向量维度)

    • 一个典型 Pre-LN Transformer layer:

      • LayerNorm₁

      • Self-Attention(QKV → scores → softmax → context)

      • Residual Add

      • LayerNorm₂

      • FFN(Linear → Activation → Linear)

      • Residual Add

  • LayerNorm 必须保存的张量

    • Forward 产生

      • 输入:$X$

      • 统计量:$\mu = \text{mean}(X), \quad\sigma^2 = \text{var}(X)$

      • 输出:$\hat{X} = \text{LN}(X)$

    • 必须保存:$\hat{X}$ 或 $X$,$\mu, \sigma^2$

    • Backward 依赖

      • 计算:$\frac{\partial \mathcal{L}}{\partial X}$

      • 必须使用 $\hat{X}, \mu, \sigma^2$

    • LayerNorm 的梯度显式依赖 forward 的统计量

  • QKV 线性映射必须保存的张量

    • Forward 产生 $Q = \hat{X} W_Q,\quad K = \hat{X} W_K,\quad V = \hat{X} W_V$

    • 必须保存 $Q, K, V$,$\hat{X}$

    • Backward 依赖

      • 参数梯度:$\frac{\partial \mathcal{L}}{\partial W_Q}=\hat{X}^\top \frac{\partial \mathcal{L}}{\partial Q}$

      • 输入梯度:$\frac{\partial \mathcal{L}}{\partial \hat{X}}=\frac{\partial \mathcal{L}}{\partial Q} W_Q^\top- \frac{\partial \mathcal{L}}{\partial K} W_K^\top- \frac{\partial \mathcal{L}}{\partial V} W_V^\top$

    • $Q/K/V$ 和 $\hat{X}$ 都不可丢

  • Attention score 与 Softmax 必须保存的张量

    • Forward 产生

      • Attention logits:$A = \frac{Q K^\top}{\sqrt{d_k}}\in \mathbb{R}^{B \times H \times S \times S}$

      • Softmax 概率:$P = \text{softmax}(A)$

      • Context:$C = P V$

    • 必须保存

      • $P$(Softmax 输出)

      • $Q, K, V$

      • $A$ 通常不存,$P$ 足够

    • Backward 依赖

      • Softmax 梯度:$\frac{\partial \mathcal{L}}{\partial A}=J_{\text{softmax}}(P)\frac{\partial \mathcal{L}}{\partial P}$

      • Score 梯度:$\frac{\partial \mathcal{L}}{\partial Q}=\frac{\partial \mathcal{L}}{\partial A} K$

    • $P$ 是整个 Transformer 中最大的激活,由于每个 token 要和序列中所有 token 计算一次相关性,故其尺寸 $\propto B \times H \times S^2$

    • 而 Transformer 的所有非 Attention 操作,都是“对每个 token 独立、逐位置计算”的,因此均 $\propto B \times S \times d$

  • Attention 输出投影必须保存的张量

    • Forward:$Y = C W_O$

    • 必须保存:$C$、$W_O$

    • Backward

      • 参数梯度:$\frac{\partial \mathcal{L}}{\partial W_O}=C^\top \frac{\partial \mathcal{L}}{\partial Y}$

  • 输入梯度:$\frac{\partial \mathcal{L}}{\partial C}=\frac{\partial \mathcal{L}}{\partial Y} W_O^\top$

  • Residual Add 必须保存的张量

    • Forward:$X_1 = X + Y$

    • 必须保存:$X$、$Y$

    • Backward:$\frac{\partial \mathcal{L}}{\partial X}\mathrel{+}= \frac{\partial \mathcal{L}}{\partial X_1}$

    • 残差路径要求保存分支输入

  • FFN 第一层(Linear₁ + Activation)

    • Forward:$H = X_2 W_1$、$A = \phi(H)$

    • 必须保存:$X_2$、$H$ 或 $A$

    • Backward 激活梯度:$\frac{\partial \mathcal{L}}{\partial H}=\frac{\partial \mathcal{L}}{\partial A}\cdot \phi'(H)$

    • 非线性决定必须保存激活相关信息

  • FFN 第二层(Linear₂)

    • Forward:$Z = A W_2$

    • 必须保存:$A$

  • 必须保存的激活张量

    模块
    必须保存的张量
    依赖原因

    LN₁

    $X, \mu, \sigma^2$

    LN 梯度

    QKV

    $\hat{X}, Q, K, V$

    参数/输入梯度

    Attention

    $P$

    Softmax Jacobian

    Attn Out

    $C$

    $W_O$ 梯度

    Residual₁

    $X, Y$

    梯度分流

    LN₂

    $X_1, \mu, \sigma^2$

    LN 梯度

    FFN₁

    $X_2, H/A$

    激活梯度

    FFN₂

    $A$

    参数梯度

    Residual₂

    $X_1, Z$

    梯度分流

  • 整个所需的激活量可以近似为

    C1BSdLN + QKV + FFN+C2BHS2Attention softmax\underbrace{C_1 \cdot B S d}_{\text{LN + QKV + FFN}} + \underbrace{C_2 \cdot B H S^2}_{\text{Attention softmax}}
  • 对于短序列($S \ll d$)而言,$B S d \gg B S^2$,FFN / QKV activation 占主导

  • 对于长序列($S \gg d$)而言,$B S^2 \gg B S d$,Attention softmax 成为绝对瓶颈,正是 Activation Recomputation、Long Context、FlashAttention、Linear Attention 等这些工作的动机来源

激活重计算(Activation Recomputation)

  • 激活重计算又称 Gradient Checkpointing、Gradient rematerialization

  • 其核心思想是在前向计算时丢弃部分激活值,然后在反向传播时重新计算这些激活值,从而用额外的计算量换取更低的显存占用

  • 如果不使用重计算,需要在两个可学习算子之间(例如前馈网络、LayerNorm 等)保存所有中间隐藏状态,以便在反向传播时直接使用它们来计算梯度

  • 而如果使用重计算,通常只在模型结构中的少数关键位置保存激活,其余激活在前向阶段直接丢弃;等到反向传播需要这些激活时,再从最近一次保存的激活出发,在线重新执行一小段前向计算把它们算出来

  • 本质上,这是重新执行一部分前向传播,用额外的计算开销来换取激活显存的显著下降

  • 在选择哪些激活作为检查点进行存储时,通常有几种策略:

    • Full(全量重计算)

      • 在这种策略下,在 Transformer 模型每一层之间的边界处保存一次激活

      • 在反向传播过程中,几乎需要对每一层的内部重新执行一次前向传播,等价于在 backward 阶段额外再跑一遍完整的 forward

      • 这种策略节省的显存最多,但在计算开销上也是最昂贵的,通常会使整体计算量和训练时间增加 30–40%

    • Selective(选择性重计算)

      • Reducing Activation Recomputation in Large Transformer Modelsarrow-up-right 这篇论文的作者分析了哪些激活张量占用内存最大,同时其重计算所需的浮点运算(FLOPs)成本最低

      • 分析结果表明,注意力相关的计算正好符合这一特征:其激活占用内存很大,但重新计算的计算代价相对较低

      • 因此,在实践中,通常会丢弃 attention 相关的激活,而保存计算代价高但内存占用相对可控的前馈网络部分的激活值

      • 以 GPT-3(175B)模型为例,这种选择性重计算策略可以在仅增加约 2.7% 的计算开销的情况下,实现约 70% 的激活显存降低

  • HFU 与 MFU

    • HFU(Hardware FLOPS Utilization):统计硬件实际执行的计算量,包括由于重计算(checkpoint / recomputation)带来的额外 FLOPs,用来衡量当前训练实现对硬件算力的利用程度

    • MFU(Model FLOPS Utilization):只统计模型在一次 forward + backward 中理论上必须执行的计算量,不包含重计算,用来衡量模型本身的计算效率,与具体实现策略弱相关

  • 如今大多数训练框架都会使用 FlashAttention,其在其优化策略中原生集成了选择性激活重计算:在反向传播时重新计算 attention 的分数和矩阵,而不是在前向传播时把它们存下来

  • 无法解决 Batch Size 带来的线性增长

    • 激活重计算的核心是减少了“每个样本”的 activation 常数项,特别是 Attention 中的 $S^2$ 激活,把“存激活”变成“多算一点”

    • 然而,它不能解决 batch 维度即 $B$ 本身带来的线性增长,即重计算降低的是“每个样本需要多少显存”,但 batch size 决定的是“要存多少个样本的激活”

    • 假设:单样本 activation(使用 FlashAttention + checkpoint)= 1.2 GB,那么有

      • $B=1$ → 1.2 GB

      • $B=8$ → 9.6 GB

      • $B=16$​ → 19.2 GB

梯度累积(Gradient accumulation)

  • 随着 batch size 变大,激活显存线性爆炸,这是由于同时驻留在显存中的激活过多

  • 而梯度累积的核心思想是把一个大 batch 拆成多个小 micro-batch,一次只在显存里放一个 micro-batch 的激活,但在数学上等价于用大 batch 更新参数

  • 相关术语

    • micro-batch size(mbs):每一次 forward / backward 实际送进模型的 batch 大小

    • global batch size(gbs):两次 optimizer step 之间,模型“等效看到”的总 batch 大小

    • grad_acc(gradient accumulation steps):在执行一次 optimizer.step() 之前,累计多少次梯度

    • 显然,gbs = mbs × grad_acc

  • 梯度累积示意图

image.png

  • 梯度累积的过程

    • 假设:$\text{mbs}=2$、$\text{grad_acc}=4$、$\text{gbs}=8$

    • 第 $i$ 次 micro-batch:forward(仅 2 个样本),backward,得到梯度 $g_i$​,但是不更新参数

    • 优化时,实际使用的是平均梯度

      gˉ=14i=14gi\bar{g} = \frac{1}{4} \sum_{i=1}^4 g_i
    • 这样,在任意时刻,显存里只有一个 micro-batch 的激活

    • 只要 $\text{mbs}$​ 固定,激活显存就是常数

  • 梯度为什么要“取平均”而不是“取和”?

    • 如果不取平均:梯度会随 grad_acc 线性放大,等价于隐式改变 learning rate

    • 取平均后:

      gˉ=Ebatch[L]\bar{g}=\mathbb{E}_{\text{batch}}[\nabla \mathcal{L}]
    • 与 batch size 无关,学习率无需调整

  • 两种技术的对比

    • activation recomputation 控制每个样本的激活大小,降低常数项

    • gradient accumulation 控制同时在显存中的样本数,降低 $B$

    • 二者通常一起使用

  • 与激活重计算相同,梯度累积的真实代价也是计算量(FLOPs)

    • forward / backward 次数增加:

      computegrad_acc\text{compute} \propto \text{grad\_acc}
    • wall-clock time 增加,但可以换来更大等效 batch、更稳定的优化、显存可控

  • 自然而然地,不同的 micro batch 的前向和后向传播是可以并行运行的,它们彼此之间相互独立,唯一的区别只是输入样本不同,而这正是一个信号——可以把训练扩展到多张 GPU 上

Checkpoint

  • Checkpoint 的思想很简单,设 Transformer Layers 的总层数 $L$,每 $k$ 层存一个 checkpoint

  • 那么在 backward 时,每个区块长度为 $k$,需要重新跑 $k-1$ 层 forward

  • FlashAttention 控制 attention backward 重算 $QK^\top$,而 checkpoint 控制 FFN / LN / MLP 重算

  • 这意味着 backward 阶段几乎全在计算,而非读写激活值

  • 在现代 GPU(A100 / H100)上,计算比读写更便宜,所以 DeepSpeed、HF Trainer 等两者默认启用

优化后的激活显存

  • 统一符号:

    • $B_{\text{micro}}$:micro-batch size

    • $B_{\text{global}}$:global batch size

    • $N_{\text{acc}}$:gradient accumulation steps

    • $S$:sequence length

    • $d$:hidden size

    • $H$:attention heads

    • $L$:Transformer layers

    • $B = B_{\text{global}} = B_{\text{micro}} \times N_{\text{acc}}$

  • 不使用任何优化(baseline)的 activation 显存(每一层)

    • token-wise 激活:

      O(B×S×d)O(B \times S \times d)
    • attention softmax:

      O(B×H×S2)O(B \times H \times S^2)
    • 因此总 activation memory:

      MactL(BSd+BHS2)M_{\text{act}} \sim L \cdot \Big( B \cdot S \cdot d + B \cdot H \cdot S^2 \Big)
    • batch、$S^2$、layer 数全部耦合在一起,最容易炸显存

  • 加上FlashAttention(去掉 $S^2$)

    • FlashAttention 的本质是不保存 $P=\text{softmax}(QK^\top)$,backward 时重算

    • 于是 $BHS^2$ 级别的 activation 被消除,只剩 token-wise 激活

      MactL(BSd)M_{\text{act}} \sim L \cdot \big( B \cdot S \cdot d \big)
    • 但此时 $B = B_{\text{global}}$,所以 batch size 一大,activation 仍然线性爆炸

  • 加上 gradient accumulation(把 $B_{\text{global}}$ 变成 $B_{\text{micro}}$)

    • gradient accumulation 的关键作用是:显存中同时存在的 batch = micro-batch

    • 于是 $B \Rightarrow B_{\text{micro}}$,公式变为

      MactL(BmicroSd)M_{\text{act}} \sim L \cdot \big( B_{\text{micro}} \cdot S \cdot d \big)
    • global batch size 与 activation 显存彻底解耦

  • 加上 checkpoint(降低 layer 维度)

    • checkpoint 的本质是:不存每一层的 activation,只存少数“锚点层”

    • 假设每 $k$ 层存一个 checkpoint,需要存的层数约为 $\frac{L}{k}$

    • 则 activation 显存变为:

      MactLk(BmicroSd)M_{\text{act}} \sim \frac{L}{k} \cdot \big( B_{\text{micro}} \cdot S \cdot d \big)

训练所需的内存计算

  • 在训练神经网络时,占用 GPU 内存的内容主要如下

    • 模型权重参数:参数量乘以每个参数的字节数

      • FP32精度:$4N$​

      • BF16混合精度:$2N$​

      • 在BF16混合精度下,需要使用FP32额外存储模型参数副本(主权重),即 $4N$

      • 混合精度并不会凭空减少模型的总显存占用,因为为了数值稳定性,仍然需要保留 FP32 的权重和梯度,甚至在梯度用 FP32 累积时,参数相关显存还会略微增加,但混合精度依然是必须的,因为:FP16/BF16 可以使用 GPU 的高速低精度算子,显著加快训练;前向传播中的激活以半精度存储,极大降低了激活显存,而激活正是训练时的主要显存瓶颈

    • 模型梯度:参数量乘以每个参数的字节数

      • FP32 精度:$4N$

      • BF16混合精度:$2N$​

      • 在BF16混合精度下,需要使用FP32额外存储模型梯度副本,即 $4N$

    • 优化器状态:Adam优化器需要额外存储动量和方差,每个 4 字节,以实现数值稳定性,即 $8N$

    • 用于计算梯度的激活值

    • CUDA 内核一般需要占用 1-2 GB

  • 这些内容以张量的形式存储,并具有不同的形状和精度

    • 形状通常由批次大小、序列长度、模型的隐藏层维度、注意力头数、词汇表大小等决定

    • 精度指的是FP32、BF16或FP8等格式,分别需要4、2或1字节来存储张量中的每个值

PyTorch Profiler

  • PyTorch 的 profilerarrow-up-right 允许精确地追踪和可视化训练过程中 CPU 和 GPU 的动态变化过程

  • Profiler 会产生训练时的运行轨迹,可以在 TensorBoard 或 Chrome 的 trace viewer 中可视化,如

    • CPU 线程如何以异步方式向 GPU 发射(launch)kernel

    • 多个 CUDA stream 如何并行处理计算与通信

    • 各个 kernel 的执行时间以及显存分配情况

  • 这个运行轨迹有助于识别训练过程中的性能瓶颈,例如:

    • 本可以重叠执行却被串行化的计算与通信

    • GPU 因等待数据传输而产生的空闲时间

    • CPU 与 GPU 之间的 CUDA 同步与内存拷贝开销

    • GPU 上 kernel 启动本身带来的额外开销

  • 理解这些执行模式对于优化分布式训练性能至关重要

  • 例如,trace 能清楚地展示:梯度同步是否与反向传播计算正确地实现了重叠

数据并行(Data Parallelism)

基本思想

  • 在多张 GPU 上各自复制一份完整的模型(这些副本通常称为 model instances),然后让每张 GPU 并行地对不同的 micro-batch 执行前向和反向传播

  • 由于并行化发生在数据维度上,因此被称为“数据并行”

  • 数据并行本质上就是梯度累积的并行化版本

image.png

  • all-reduce 操作

    • 由于每张 GPU 使用的是不同的 micro-batch,各个 GPU 上算出来的梯度自然也不同

    • 为了让所有 GPU 上的模型副本保持一致,需要在参数更新之前,对各个 GPU 的梯度进行平均

    • 上述这个过程通过 all-reduce 操作来完成,这是迄今为止的第一个分布式通信原语,负责在不同 GPU(甚至不同节点)之间进行梯度的同步与通信

    • 这个操作发生在反向传播阶段、optimizer step 之前

朴素数据并行

  • 实现步骤

    image.png - 等所有反向传播全部完成,得到完整梯度 - 再对所有 GPU 的梯度执行一次 all-reduce,同步梯度 - 最后进行参数更新

  • 存在的问题

    • 计算结束 → 再通信 是典型的反模式(A BIG NO-NO)

    • 原因是在 all-reduce 通信期间,GPU 基本处于空闲状态,白白浪费了大量算力

    • 因此正确思路是,尽可能让通信与计算重叠(overlap),让它们同时进行

    • 也即一边做反向传播计算,一边同步已经算好的梯度,这样 GPU 才不会因为等待通信而停下来

第一步优化:梯度同步与反向传播重叠执行

  • 在做反向传播计算的同时,就开始做梯度同步

  • 在反向传播中:最后一层的梯度最先算完,越靠前的层,梯度越晚算完

  • 因此某一层的梯度一旦计算完成,就已经可以立刻进行 all-reduce,而不必等待整个模型的 backward 结束

  • 在 PyTorch 中,可以通过给每个参数注册 backward hook 来实现:当某个参数的梯度刚刚计算完成,就立刻触发一次 all-reduce

  • 此时其他参数的梯度仍在计算,backward 仍在继续,梯度同步被“切碎”,并嵌入到 backward 过程中

  • 这样,大部分通信时间被隐藏在计算时间里,GPU 等待通信的空闲时间大幅减少数据并行的整体效率显著提升

image.png

第二步优化:梯度分桶(bucketing)

  • 在第一步优化中,每个参数的梯度一旦计算完成,就立刻触发一次 all-reduce

  • 这在逻辑上是对的,但在工程上存在问题:GPU 和通信库(NCCL)非常不擅长处理大量小张量

  • 具体代价包括:每次 all-reduce 都有固定启动开销,而频繁的小张量通信,kernel launch 多、NCCL 调度频繁、网络带宽利用率低,结果导致通信被“碎片化”了

  • 解决思路:不要对每个参数的梯度单独 all-reduce,而是把多个梯度拼在一起,一次性同步,这就是 gradient bucketing

  • 典型流程:

    • 框架预先定义若干个 bucket(按大小,比如 25MB)

    • 反向传播时:某个参数梯度计算完成后,会被拷贝/映射进对应的 bucket

    • 当一个 bucket 被填满,或其包含的梯度全部完成:立刻触发一次 all-reduce

    • backward 继续执行其他层

  • 从而变成 bucket 粒度的 overlap,而不是 parameter 粒度的 overlap

  • 其优点在于,减少 kernel / NCCL 启动次数、提高带宽利用率、具有更少的同步点(sync points)

第三步优化:梯度累积时延迟同步(no_sync)

  • 梯度累积的最后一次才会更新参数

    • 数据并行:每张 GPU 对自己的 micro-batch 做 forward/backward,然后用 all-reduce 同步梯度

    • 梯度累积:连续做多次 forward/backward,暂不更新参数,最后一次再 optimizer.step()

    • 如果把数据并行直接套在梯度累积上,最直接的实现是:每一次 backward,都自动触发一次 all-reduce

    • 但这是完全没有必要的,梯度只是被累加,并不会立刻用来更新参数,中间步骤的同步是纯粹的通信开销

    • 因此,梯度累积期间,不需要保持不同 GPU 之间的梯度一致

  • 正确做法:前 $K-1$ 次 backward 不做梯度同步,第 $K$​ 次 backward 正常 all-reduce

  • PyTorch DDP 提供了一个非常直接的机制:model.no_sync(),其语义是在这个作用域内,禁止 DDP 在 backward 时触发 all-reduce

  • 通过延迟同步,解决“梯度同步频率过高”的问题,降低通信开销

  • 通信缓冲区

    • NCCL / GPU DMA 更擅长处理连续内存,而非连续张量需要额外 copy,增加延迟

    • 因此在执行通信操作时,张量必须在内存中连续,以避免冗余的内存副本

    • 为了实现最佳效果,工程上通常会预先分配一大块激活大小的连续缓冲区,把梯度 / bucket 拷贝进去做通信

    • 虽然这提高了通信效率,但是这些通信缓冲区也会占用额外显存,部分提高训练期间的显存使用峰值

相关分析

  • 梯度累积与数据并行

    • 已知

      • 数据并行 GPU 数:$dp$

      • micro-batch size:$mbs$

      • 梯度累积步数:$grad_acc$

    • 于是等效的批次大小

      Global Batch Size=gbs=mbs×grad_acc×dp\text{Global Batch Size} = gbs = mbs \times grad\_acc \times dp
    • 假设目标是一个固定的 $gbs$

      • 可以选择用更多 GPU(更大的 $dp$)或者更多梯度累积步数(更大的 $grad_acc$​)

      • 工程上,更倾向于先使用更多的 GPU,这是由于

      • 数据并行:不同 GPU 同时做 forward/backward,wall-clock 时间显著下降,是真正的并行加速

      • 而梯度累积:多次 forward/backward 必须顺序执行,只省显存,不省时间,甚至会拉长一次 optimizer step 的耗时

      • 因此工程经验是:能用数据并行解决的 batch size,绝不用 grad_acc;只有当 GPU 不够用时,才增加 grad_acc

    • 通过二者配合,activation 显存 ∝ $mbs$,通信频率 ∝ $\frac{1}{K}$(有 no_sync),参数更新频率 ∝ $\frac{1}{K}$

  • 大规模情况下的环延迟问题

    • 当数据并行规模扩展到 512 张 GPU 甚至更多 时(具体阈值取决于网络拓扑和带宽),梯度同步中的 all-reduce 通信会出现一个新的瓶颈

    • 不再是带宽受限,而是被 ring latency(环延迟)限制

    • 在 ring all-reduce 中:GPU 被组织成一个逻辑环,梯度数据需要在环上逐跳传播,即使数据量不大,也必须等信号绕完整个环一圈

    • 当 GPU 数量很大时:环变长,单次 all-reduce 的最小耗时被“信号传播时间”锁死

    • 这部分时间:无法通过并行计算掩盖,无法被 backward 继续重叠

  • 大规模情况下“通信与计算重叠”开始失效

    • 在中小规模 DP 下:backward 计算时间足够长,all-reduce 可以“藏”在计算后面,GPU 几乎不空闲

    • 但在超大规模时:单卡计算量 ∝ $\frac{1}{dp}$,通信延迟 ≈ 常数或缓慢下降

    • 于是出现:通信时间 > 可重叠的计算时间

    • 结果是:backward 结束了,通信还没结束,GPU 被迫等待

  • 大规模情况下的后果

    • 计算效率下降:GPU 有更多空转时间,FLOPs 利用率下降

    • 吞吐量下降:每增加一张 GPU,系统总吞吐的提升越来越小,甚至可能变差

    • 这也就是典型的规模扩展的边际收益递减

    • 这是因为数据并行的通信模式是:所有 GPU ↔ 所有 GPU,每一步都要全局同步梯度

    • 当规模极大时通信同步本身成为主导成本,是由通信拓扑和物理网络决定的根本限制

DeepSpeed ZeRO

基本思想

  • 在标准数据并行中,每张 GPU 都有:

    • 一份完整模型参数

    • 一份完整梯度

    • 一份完整 optimizer state

  • 而 DP 扩展的关键误区是:“用更多 GPU 做数据并行,单卡显存会变小”

  • DP 只分数据,不分模型状态,所以 GPU 数量变多,但是单卡显存不变,甚至通信更繁杂,这就是为什么 DP 无法训练超出单卡显存上限的模型

  • 这就是从 DP 走向 ZeRO 的本质转折:为什么所有 GPU 都要保存同一份模型状态?

  • ZeRO 的全称是 Zero Redundancy Optimizer,意思是在数据并行组内,不再冗余存储模型状态,每个 GPU 只保留一部分

  • ZeRO 的不同阶段

    • ZeRO Stage 0:标准的数据并行

    • ZeRO Stage 1:划分 optimizer state,这是由于 optimizer state 占用最大(Adam),因此 optimizer state 在 DP 组内按参数分片,每张 GPU 只保存 $\frac{1}{dp}$ 的 optimizer state,但参数、梯度仍然是完整的

    • ZeRO Stage 2:再划分梯度,梯度也按参数分片,每张 GPU 只保存 $\frac{1}{dp}$ 的梯度,此时 backward 后不再做完整 all-reduce,而是 reduce-scatter + all-gather

    • ZeRO Stage 3:最后划分参数,这是最激进的一步,参数本身也被分片,单张 GPU 上没有完整模型,在 forward / backward 时用到某一层参数就临时 all-gather,用完立刻释放

  • ZeRO 不同阶段的显存占用

    • 回顾模型训练所需的显存占用

      • 模型参数(半精度,即 BF16 / FP16):$2\Psi$

      • 模型梯度(半精度,即 BF16 / FP16):$2\Psi$

      • 模型参数的 FP32 主副本以及优化器状态:$4\Psi + (4\Psi + 4\Psi)$

      • 模型梯度的 FP32 版本:$4\Psi$​(可选项,仅在需要以 FP32 累积梯度时才会存在)

    • 那么 ZeRO 不同阶段对显存占用的优化如下($N_d$ 表示 DP 度,即数据并行的规模)

    zero_memory.svg

ZeRO Stage 1:切分优化器状态(Partitioning Optimizer States)

  • 基本思想

    dp_zero1.gif - 在普通的数据并行(vanilla DP)中,所有 DP rank 在反向传播结束后都会拿到完全相同的梯度,并且各自独立地执行一模一样的优化器更新,这显然存在大量重复计算和重复存储,那么能不能既避免这些重复工作,又同时降低显存占用? - 在 ZeRO-1 中,优化器状态(optimizer states)被切分成 $N_d$ 份,这意味着每个 DP rank 只保存 $\frac{1}{N_d}$ 的优化器状态,在一次 optimizer step 中:只有 $\frac{1}{N_d}$​ 的 FP32 参数会在该 rank 上被更新 - 但问题是前向传播时,每个模型副本仍然需要完整的一套参数,因此,在 ZeRO-1 中,必须在 optimizer step 之后,额外加入一次 all-gather 操作——把各个 rank 上更新好的参数切片重新收集,使每个 DP rank 再次拥有完整、同步的模型参数

  • 优化器状态只在“参数更新(optimizer step)”阶段用到,而不是在反向传播阶段

    • 反向传播(backward):输入:激活值、参数,输出:梯度 $\nabla W$,不使用优化器状态

    • 参数更新(optimizer step):输入:梯度($\nabla W$,FP16 / FP32)、优化器状态(如 Adam 的 $m, v$)、FP32 主参数,输出:更新后的 FP32 参数、更新后的优化器状态

    • 以 AdaW 为例

      mt=β1mt1+(1β1)Wtvt=β2vt1+(1β2)(Wt)2Wt=Wt1ηmtvt+ϵ\begin{align} m_t &= \beta_1 m_{t-1} + (1-\beta_1)\nabla W_t \\ v_t &= \beta_2 v_{t-1} + (1-\beta_2)(\nabla W_t)^2 \\ W_t &= W_{t-1} - \eta \frac{m_t}{\sqrt{v_t}+\epsilon} \end{align}
  • 具体训练步骤

    • 前向传播:每个 DP rank 使用完整的 BF16 参数,但处理的是不同的 micro-batch

    • 反向传播:每个 DP rank 计算基于各自的数据计算梯度

    • All-reduce 梯度:all-reduce(或等价的 reduce-scatter + all-gather),每个 rank 拿到完整梯度

    • 局部优化器更新:每个 rank 只更新 $\frac{1}{N_d}$ 的优化器状态,得到对应的 $\frac{1}{N_d}$ FP32 参数,再转换成 $\frac{1}{N_d}$​ BF16 参数

    • All-gather 参数:对 BF16 参数执行 all-gather,把缺失的参数切片同步回每个 rank(这是 ZeRO 新增的通信步骤,vanilla DP 中不存在)

  • 在通信与计算重叠方面,主要有两种策略

    dp_zero1_overlap.svg - 在 optimizer step 中重叠:每更新完一部分参数,就立刻开始 all-gather,与后续参数更新并行 - 在 forward 中重叠:将某一层参数的 all-gather,与该层的前向计算重叠 - 但这些优化实现复杂,需要精细的 hook / bucket 管理,实践中通常直接使用 PyTorch 的 FSDP / ZeRO-3 实现,并把整个模型作为一个 FSDP unit - 在 ZeRO-1 中,优化器状态已经被切分,每个 rank 实际只更新 $\frac{1}{N_d}$ 参数,那么既然只需要一部分梯度来更新,为什么还要在每个 rank 上保留完整梯度?——这也就是 ZeRO Stage 2 的出发点

ZeRO Stage 2:加入梯度切分(Adding Gradient Partitioning)

  • 基本思想

    dp_zero2.gif - 既然在每个 DP rank 上,只需要与其本地优化器状态分片对应的那部分梯度,那么也可以像切分优化器状态一样,把梯度本身也切分 - 因此,在 ZeRO-2 中,反向传播时不再对梯度执行 all-reduce,而是直接执行 reduce-scatter,每个 rank 只保留 $\frac{1}{N_d}$ 的梯度分片 - 如果使用 FP32 梯度累积(即 BF16 梯度先转换并累积到 FP32),每个 rank 只需要保留 $\frac{1}{N_d}$ 的 FP32 梯度 - 这些 FP32 梯度来自 reduce-scatter 后的 BF16 梯度,只用于更新本地的优化器状态分片 - 在 optimizer step 中,这 $\frac{1}{N_d}$​​ 的 FP32 梯度仅用于更新本 rank 持有的优化器状态和对应的参数分片

  • 优化器状态只需要分片对应的那部分梯度的数学推导

    • 设参数向量 $W$ 被切分为:

      W=i=1NdW(i)W = \bigcup_{i=1}^{N_d} W^{(i)}
    • Adam / SGD 的更新是逐参数独立的

      W(i)W(i)ηf(W(i),opt_state(i))W^{(i)} \leftarrow W^{(i)} - \eta \cdot f(\nabla W^{(i)}, \text{opt\_state}^{(i)})
    • 因此,对于rank $i$,只更新 $W^{(i)}$,只需要 $\nabla W^{(i)}$,只需要 $\text{opt_state}^{(i)}$

  • 与 ZeRO Stage 对比:

    • 在通信层面,ZeRO-2 与 ZeRO-1 几乎完全一致,都包含 reduce-scatter(梯度)、all-gather(参数)

    • 唯一的区别在于 ZeRO-2 会在通信完成后立即释放不再需要的梯度内存,梯度从“逻辑切分” → “物理只存在一份”

    • 因此可以说:在通信角度,ZeRO-2 等价于 vanilla DP / ZeRO-1,在显存角度,ZeRO-2 明显优于 ZeRO-1

    • 除了实现复杂度更高之外,ZeRO-2 相比 ZeRO-1 几乎没有额外运行时开销,因此在实践中 ZeRO-2 通常是比 ZeRO-1 更优的选择

  • 通信与计算重叠

dp_zero2_overlap.svg

ZeRO-3:引入参数分片(FSDP)

  • 基本思想

    • 在 Stage 3 中,将前面在 DP 副本之间分片优化器状态和梯度的做法,进一步扩展为对模型参数本身进行分片

    • PyTorch 对这一阶段的原生实现称为 FSDP(Fully Sharded Data Parallelism)

  • 参数被分布式存储时,如何进行前向 / 反向传播?

    • 解决方法:在需要时按需聚合(all-gather)参数

    • 在执行前向传播、按顺序遍历各层时,会按需取回该层所需的参数,并在使用完之后立即将其从显存中释放

      dp_zero3_fwd.svg

    • 在反向传播中类似,只是流向相反

      dp_zero3_bwd.svg

  • 通信开销问题

    dp_zero3_overlap.svg - 通信代价分析 - 在一次训练 step 中,需要在 forward 和 backward 的过程中持续进行 all-gather 操作 - 相较于 ZeRO-2,这会在一次训练 step 中额外引入 $2\cdot num\_layers−1$​ 次 all-gather 操作,每一次 all-gather 都带来一个小的基础通信延迟开销 - 在前向传播中,当某一层需要参数时,进行一次 all-gather,这会带来一次 $\Psi$ 级别的通信开销 - 由于在 forward 后我们立刻丢弃了该层参数,在反向传播时,还需要再次 all-gather 参数,又引入一次 $\Psi$ 的通信开销 - 最后梯度仍然需要像 ZeRO-2 一样进行 reduce-scatter,这同样需要 $\Psi$​ 的通信开销 - 因此,ZeRO-3 的总通信代价为 $3\Psi$,而 ZeRO-2 的通信代价为 $2\Psi$ - 通信与计算重叠 - 这听起来似乎是非常大的通信开销,但实际上影响并不严重,因为可以通过通信与计算重叠来隐藏大部分通信时间,这种技术称为预取(prefetching) - 在执行第 $n$ 层的前向传播时,同时 all-gather 第 $n+1$ 层的参数 - 在执行 第 $n$ 层的反向传播时,同时 all-gather 第 $n-1$ 层的参数 - 只要 DP 规模不要过大,这种重叠就可以有效发挥作用,从经验来看,DP 不应超过 512

相关分析

  • DP 和 ZeRO 的局限性

    • 可以看到,激活值不能像参数、梯度、optimizer state 那样做 shard(分片),原因不是工程上“没实现”,而是在数据并行中,激活值天生就不是冗余的

    • 每个 GPU 拿到的是不同的 micro-batch,不同 GPU 上的激活值数值不同、语义不同,它们不是“同一份数据的多份拷贝”,而 sharding 的前提是有重复

    • 而激活值恰恰又会随着批次的增加而线性增长,这是迫切需要被解决的

    • ZeRO Stage 3 的本质是复制计算,切分存储,每个 GPU 执行完整模型的 forward / backward,但只存自己负责的一小片参数 / 梯度 / optimizer state,因此 forward 需要 all-gather 参数,backward 需要 reduce-scatter 梯度,这也就导致 ZeRO-3 的核心成本不是计算,而是“参数通信”

    • 与 ZeRO 的核心思想不同,张量并行(Tensor Parallelism)的出发点是切分计算本身

    • TP 不是“减少冗余副本”,而是把一个 layer 的数学运算本身拆到多个 GPU 上做,每个 GPU 只计算模型的一部分,只持有这一部分对应的参数,激活值也只是一部分

配置说明

  • 配置文件示例

  • 参数说明

    参数名称
    可选值/类型
    示例
    含义

    train_batch_size

    整数

    32

    每个训练步骤的批量大小

    gradient_accumulation_steps

    整数

    1

    每隔多少步进行一次梯度更新(梯度累积)

    steps_per_print

    整数

    200

    每多少步打印一次日志信息

    fp16.enabled

    布尔值

    true

    启用混合精度训练(FP16)

    fp16.loss_scale

    整数

    0

    动态损失缩放因子,通常设置为 0 表示自动选择

    fp16.initial_scale_power

    整数

    16

    初始损失缩放的幂

    zero_optimization.stage

    0, 1, 2, 3

    2

    Zero 优化阶段,0 表示没有优化,1 表示优化参数,2 表示优化梯度,3 表示优化参数和梯度

    zero_optimization.offload_param.device

    cpu, nvme

    cpu

    参数卸载到的设备类型,可以选择 cpunvme

    zero_optimization.offload_optimizer.device

    cpu, nvme

    cpu

    优化器状态卸载到的设备类型

    optimizer.type

    字符串(如 Adam, AdamW 等)

    Adam

    使用的优化器类型

    optimizer.params.lr

    浮动数值

    3e-5

    学习率

    optimizer.params.betas

    数组

    [0.9, 0.999]

    优化器的 beta 参数

    optimizer.params.eps

    浮动数值

    1e-8

    优化器的 eps 参数

    scheduler.type

    字符串(如 WarmupLR, CosineAnnealingLR 等)

    WarmupLR

    学习率调度器类型

    scheduler.params.warmup_min_lr

    浮动数值

    1e-7

    学习率的最小值(用于预热阶段)

    scheduler.params.warmup_max_lr

    浮动数值

    3e-5

    学习率的最大值(用于预热阶段)

    scheduler.params.warmup_num_steps

    整数

    500

    预热阶段的步数

    wall_clock_breakdown

    布尔值

    false

    是否打印每个阶段的壁钟时间分解

FSDP(Fully Sharded Data Parallel)

  • 基本原理

    • 一种参数分片技术,为超大规模模型的分布式训练设计

    • 参数分片:模型的参数被分片存储到不同的 GPU,每个 GPU 仅存储自己负责的参数分片(Sharded Parameters)

    • 数据分发:与 DDP 类似

    • 前向传播:FSDP 将需要的参数分片加载到显存,在完成计算后,卸载这些参数以节省显存

    • 后向传播:梯度计算完成后,梯度也被分片,并通过 AllReduce 操作同步到所有 GPU 优化器分片:优化器状态(如动量)也被分片存储并在必要时同步

  • 参数 ShardingStrategy 的不同取值决定了模型的划分方式

    • FULL_SHARD:将模型参数、梯度和优化器状态都切分到不同的 GPU 上,类似 ZeRO-3

    • SHARD_GRAD_OP:将梯度、优化器状态切分到不同的 GPU 上,每个 GPU 仍各自保留一份完整的模型参数,类似 ZeRO-2

    • NO_SHARD:不切分任何参数,类似 ZeRO-0

  • 配置文件示例

  • 参数说明

    参数名称
    可选值
    示例
    含义

    world_size

    正整数

    8

    分布式训练中的总进程数,通常等于设备数

    local_rank

    正整数

    0

    当前进程的本地进程号,用于标识分布式训练中的某个节点

    shard_optimizer_state

    truefalse

    true

    是否对优化器状态进行切分以节省内存

    mixed_precision

    truefalse

    true

    是否启用混合精度训练,以提高训练速度并减少内存占用

    fp16

    truefalse

    true

    是否启用 16 位浮动点数精度 (FP16),在一些硬件上可以显著提升训练速度

    activation_checkpointing

    truefalse

    true

    是否启用激活检查点,以减少显存使用

    device

    cudacpu

    cuda

    训练使用的设备类型,通常为 cuda(GPU)或 cpu

    offload_params

    truefalse

    true

    是否将模型参数卸载到 CPU 内存而不是 GPU 显存

    offload_optimizer_state

    truefalse

    false

    是否将优化器状态卸载到 CPU 内存而不是 GPU 显存

    use_reentrant

    truefalse

    true

    是否使用重入式 (reentrant) 机制,这对于某些特定的训练优化是必需的

    model_parallel_size

    正整数

    2

    模型并行的大小,指定了在多个设备之间划分模型的数量

    checkpoint_interval

    正整数

    1000

    每隔多少步进行一次检查点保存

    checkpoint_dir

    字符串(路径)

    ./checkpoints

    存储检查点文件的目录路径

张量并行(Tensor Parallelism)

基本思想

  • 张量并行(Tensor Parallelism,TP)是一种能够对权重、梯度、优化器状态,甚至激活值进行分片的并行算法,而且不需要在计算前将它们全部聚合

  • 张量并行利用了矩阵乘法 $A \times B$ 的数学性质

    • 按列拆分 $B$:

      AB=A[B1  B2  ]=[AB1  AB2  ]A \cdot B = A \cdot [B_1\;B_2\; \cdots ] = [AB_1 \;A B_2\; \cdots ]
    • 按行拆分 $A$:

      AB=[A1A2][B1B2]=i=1nAiBiA \cdot B =\begin{bmatrix} A_1 \\ A_2 \\ \vdots \end{bmatrix} \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} =\sum_{i=1}^{n} A_i B_i
  • 这意味着,可以通过两种方式来计算矩阵乘积:

    • 逐列地对 $B$ 的各个列块分别进行乘法

    • 或者逐行地对 $A$​ 的各个行块分别计算,然后将结果相加

  • 在神经网络中,矩阵乘法通常写成如下形式:

    X×WX \times W
    • 其中,$X$ 表示输入或激活值(activations),$W$​ 表示线性层(Linear layer)的权重矩阵

  • 一个简单的例子

TP diagram

并行化矩阵乘法运算

  • 在张量并行中,沿着某一个维度将张量切分成 $N$ 个分片(shards),并分布到 $N$ 张 GPU 上

  • 矩阵既可以按列切分,也可以按行切分,因此形成了列并行(column parallelism)和行并行(row parallelism)两种方式

  • 列并行(Column-wise / Column-linear Sharding)

    image.png - 第一种方式是按列分片(也称 column-linear) - 将完整的输入矩阵 $X$ 复制到每一个 worker 上 → 这需要一次 broadcast - 将权重矩阵 $W$ 按列切分,每个 worker 使用完整的 $X$ 与本地的部分权重 $W_i$ 相乘 - 最后使用 all-gather 将各个 worker 的输出拼接成完整结果

  • 行并行(Row-wise / Row-linear Sharding)

    image.png - 第二种方式是按行分片(也称 row-linear) - 将权重矩阵 $W$ 按行切分,同时也必须将输入 $X$ 进行相应切分 → 这里不再使用 broadcast,而是使用 scatter (这是我们遇到的第四种通信原语) - 每个 worker 计算本地部分的结果 - 各个 worker 的输出形状已经正确,但需要求和得到最终结果 → 使用 all-reduce

Transformer Block 中的张量并行

  • MLP 层的并行化

    image.png - MLP 通常是:

    XW1HactivationW2YX \xrightarrow{W_1} H \xrightarrow{\text{activation}} \xrightarrow{W_2} Y
    • 第一层($W_1$):Column Parallel

      • hidden dim 通常很大(如 4×)

      • 切输出最自然

      • 得到的是 sharded activation

      • 在实际训练中,broadcast 通常并不需要,因为可以保证输入在各个 TP rank 上本来就是同步的

    • 第二层($W_2$):Row Parallel

      • 输入正好是上一步的 sharded hidden

      • 每个 GPU 只处理自己的那一块

      • 最终 all-reduce 得到完整输出

    • 如果用 Row → Column:

      • 中间多一次 all-reduce

      • 直接增加通信在关键路径上的时间

    • MLP 的 TP 布局不是随便选的,是“通信最少”的结果

  • Multi-Head Attention 的并行化

    image.png - 在注意力模块中,可以采用类似的策略:Query($Q$)、Key($K$)、Value($V$)矩阵使用列并行,输出投影(output projection)可以视为行并行 - 对于多头注意力来说,列并行有一个非常自然的解释——每一张 GPU 负责计算一个或一部分 attention heads

  • 之所以能够如此高效地在 Attention 和 MLP 模块中应用张量并行,是因为它们都具有天然可独立分解的维度:

    • Attention 模块可以沿着 $num_attention_heads$ 维度并行(每个 attention head 彼此独立)

    • MLP 模块可以沿着 $hidden_dim$​ 维度并行(前馈网络在该维度上的计算彼此独立)

相关分析

  • 张量并行度的限制(Attention Heads)

    • 需要注意的是,张量并行的并行度不应超过 attention heads 的数量,因为我们是沿着 $num_attention_heads$​ 维度对 QKV 投影进行切分的

    • 在使用 GQA 时,有 $num_attention_heads$ 个 query heads,但只有 $num_kv_heads$​ 个 key/value heads,且满足

      num_attention_headsnum_kv_headsnum\_attention\_heads \ge num\_kv\_heads
    • 在这种情况下,理论上仍然可以设置 $TP = num_attention_heads$,但需要确保 K/V heads 在不同 GPU 之间保持正确同步

    • 例如,Llama-3 8B 模型有 32 个 query heads,但只有 8 个 key/value heads,因此,虽然 TP 理论上可以扩展到 32,但需要非常谨慎的实现来保证 K/V head 在各个 TP worker 之间的一致性

  • 张量并行中的前向传播瓶颈

    • 在张量并行的前向传播中,每一层 decoder 的前向传播中,都会遇到一个 all-reduce 的同步点,这个同步点无法与计算完全重叠

    • 这是因为在 LayerNorm 之前,必须先将各个 TP rank 上的部分结果合并

      • 前面的线性层 / Attention / MLP 是 TP 分片计算的,但在 LayerNorm 之前,必须通过 all-gather / all-reduce,使得每个 TP rank 都拿到完整的激活值

      • 也就是说,对任意 TP rank $r$:

        x(r)=x(逐元素完全相同)x^{(r)} = x \quad \text{(逐元素完全相同)}
      • 这是 LayerNorm 的一个硬约束:它需要对完整 hidden dimension 做统计量(均值、方差),不能在 shard 上算

        LN(x)=γxμ(x)σ2(x)+ϵ+β\text{LN}(x) = \gamma \cdot \frac{x - \mu(x)}{\sqrt{\sigma^2(x) + \epsilon}} + \beta
    • LayerNorm 的前向一致 ⇒ 反向梯度天然一致

      • 因为每个 TP rank 的输入 $x$ 完全相同,forward 输出也完全相同,loss 也是相同的

      • 所以反向传播时,对 LayerNorm 参数的梯度:

        γ(r)=γ,β(r)=β\nabla_\gamma^{(r)} = \nabla_\gamma,\quad \nabla_\beta^{(r)} = \nabla_\beta
      • 对所有 TP rank $r$ 都成立

      • 因此 LayerNorm 的梯度在各 TP rank 上是“天然一致”的,不需要 all-reduce 来同步

    • Dropout 要“同步随机种子”

      • Dropout 本质上引入了随机性,其等价于:

        y=xm,mBernoulli(p)y = x \odot m,\quad m \sim \text{Bernoulli}(p)
        • 其中 $m$ 是一个随机 mask

      • 如果 TP rank 使用不同的随机种子,即便输入 $x$ 在 TP rank 上是一样的,但是 rank 0 用 mask $m^{(0)}$,rank 1 用 mask $m^{(1)}$,且 $m^{(0)} \neq m^{(1)}$,那么 $y^{(0)} \neq y^{(1)}$

      • 接下来会发生灾难性的后果——backward 梯度不同、LayerNorm 前提被破坏、参数更新不一致、训练数值不再等价于单卡

    • 一些系统(如 Megatron-LM、Nanotron)通过以下方式缓解这一问题:

      • 在 FC1 计算过程中部分重叠 all-gather

      • 当矩阵乘法的一部分结果计算完成后,就立即将其发送到其他 GPU

      • 剩余部分仍在继续计算

  • 激活内存与通信的权衡

    • 引入额外的通信:ZeRO 的通信如 all-reduce / reduce-scatter,发生在 backward 阶段,可以和反向计算流水重叠;而张量并行在模型的前向计算路径(LayerNorm 之前)中直接引入了多种分布式通信原语,而这些通信很难被完全隐藏或与计算重叠

    • 激活内存的节省被破坏:张量并行确实可以将矩阵乘法中的中间激活值分片到多个 GPU,从而减少单卡激活内存占用;但仍然存在限制,对于 LayerNorm 等操作,仍然需要收集完整激活,因此并不能获得“完全”的激活内存节省

    • 吞吐率的下降:TP 引入了强依赖于网络条件的通信需求,由于某些 all-reduce 操作无法被完全隐藏,它们会直接落在前向传播的关键路径上——也就是决定前向传播最短完成时间的那条操作链

    • 因此,最终性能是以下因素之间的权衡结果:计算与显存节省带来的收益和额外通信开销带来的损失

  • 没有任何一种并行方式是万能的,现实最优解永远是:

    TP×DP×(ZeRO / FSDP)×Activation Checkpointing\text{TP} \times \text{DP} \times \text{(ZeRO / FSDP)} \times \text{Activation Checkpointing}
    • TP:解决 activation

    • ZeRO:解决参数 / optimizer

    • DP:提升吞吐

    • checkpoint:压 activation 峰值

序列并行(Sequence Parallelism)

基本思想

  • 序列并行(Sequence Parallelism,SP)是对张量并行中存在的问题的自然扩展,是为了处理张量并行无法覆盖的操作,比如 Dropout 和 LayerNorm

  • 与 TP 不同,SP 沿序列维度(sequence length)切分激活和计算,而不是沿隐藏维度,这样可以把原本依赖完整隐藏维度的操作(如 LayerNorm)所需的激活显存分摊到多张 GPU 上

  • 序列并行和上下文并行的区别

    • 这里的序列并行专门用于配合 TP,主要处理 Dropout 和 LayerNorm 操作

    • 在处理超长序列时,注意力计算会成为瓶颈,需要使用 Ring Attention 等技术,这类方法有时也被叫作序列并行,但为了区分,称其为上下文并行在后文介绍

  • 以 LayerNorm 为例解释为什么要沿序列维度处理

    LayerNorm(x)=γxμσ2+ϵ+βμ=mean(x),σ2=var(x)\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\\ \mu = \text{mean}(x), \quad \sigma^2 = \text{var}(x)
    • $\mu$ 和 $\sigma^2$ 是沿隐藏维度 $h$ 计算的均值和方差

    • LayerNorm 需要访问完整的 hidden_dim才能计算正确

    • Dropout 也需要沿 sequence 保持 mask 的一致性或对应性

    • 因此,即使这些操作计算量很小,它们的激活仍然占用大量显存

  • 基本思想

    • SP 通过沿序列维度切分激活来分摊显存压力

    • 每个 GPU 只处理自己序列分片的 LayerNorm 或 Dropout

    • 激活显存需求从 $O(B \cdot S \cdot H)$ 被切分到 $O(B \cdot S/N_{seq} \cdot H)$,$N_{seq}$​ 是 SP 切分的 GPU 数

TP + SP 后的变化

in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter             in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather            SP region needs full hidden_dim

  • 首先,TP 中引入了两个操作标记 f 和 f*,它们是一对共轭操作(conjugate pair)

    Pass
    f
    f*

    Forward

    no-op(不做操作,因为激活已经在各个 TP rank 上重复)

    all-reduce(同步激活,保证计算正确)

    Backward

    all-reduce(同步梯度)

    no-op(梯度已重复)

  • 共轭对的含义:在每个传递中,当一个是 no-op,另一个是 all-reduce;在另一传递中刚好相反

  • SP 处理沿序列维度切分的操作(如 LayerNorm、Dropout):

    • 在 SP 区域,避免使用 all-reduce,因为 all-reduce 需要收集完整的激活,增加峰值显存,这不是 SP 的目的

    • g / g* 操作用于 TP ↔ SP 的转换:

      • g: all-gather → 把沿序列切分的激活拼接回来,供 TP 层使用

      • g*: reduce-scatter → 把 TP 输出切分回 SP 序列分片

TP+SP 的前向传播流程

image.png

  • 初始 LayerNorm(SP 区域)

    • 输入 X1*, X2*,形状 $(B, S/2, H)$,已沿序列切分

    • 每个 GPU 独立计算 LayerNorm → 得到 Y1*, Y2*

  • 第一次转换(SP → TP)

    • g 操作(all-gather) → 拼接 Y1* 和 Y2*

    • 恢复完整序列长度 $(B, S, H)$,供列线性层(column-linear)使用

  • 第一个线性层(TP 区域,列线性)

    • 输入 Y 被切分隐藏维度

    • 每个 GPU 独立做 GELU → 输出 Z1*, Z2*,形状 $(B, S, H/2)$

  • 第二个线性层(TP 区域,行线性)

    • 恢复隐藏维度

    • 输出 W1, W2 需要在 GPU 之间累加

    • 最终输出 $(B, S, H)$

  • 第二次转换(TP → SP)

    • g* 操作(reduce-scatter) → 切回序列维度

    • 输出 W1*, W2*,形状 $(B, S/2, H)$​,供下一个 SP 层使用

  • 最大激活显存降低:

    • TP 单独使用:某些地方需要存储 $(B, S, H)$ 的激活

    • TP + SP:每次激活只存部分序列或部分隐藏维度

      max activation=BSHtp\text{max activation} = B \cdot S \cdot \frac{H}{tp}
    • 显著减轻 GPU 显存压力

  • 前向传播中隐藏状态在 TP 与 SP 的变化

    • TP-only

      Region
      隐藏维度 h*
      序列维度 s*

      进入 TP(column-linear)

      h* 被切分 (weight_out 被切分)

      s* 完整

      TP 区域

      h* 被切分

      s* 完整

      退出 TP(row-linear)

      h* 恢复完整 (weight_out 全 + all-reduce 保证正确)

      s* 完整

      SP 区域

      h* 完整

      s* 完整

      • 隐藏维度 h* 被切分,序列维度 s* 保持完整

      • 在 row-linear 时,需要对隐藏维度做 all-reduce 以保证计算正确

      • SP 区域因为没有 SP,所以仍然存储完整激活

相关分析

  • TP + SP

    • SP 作用在序列维度,沿 s* 切分激活

    • 在 row-linear 时,由于不需要存储完整激活,因此只需要 reduce-scatter

    • TP 区域中,列线性需要完整序列,因此先 all-gather;行线性输出后再 reduce-scatter 回 SP 分片

    Region
    隐藏维度 h*
    序列维度 s*

    进入 TP(column-linear)

    h* 被切分 (weight_out 被切分)

    s* all-gather 到完整长度

    TP 区域

    h* 被切分

    s* 完整

    退出 TP(row-linear)

    h* 恢复完整 (weight_out 全 + reduce-scatter 保证正确)

    s* reduce-scatter 回切分

    SP 区域

    h* 完整

    s* 切分

  • Embedding 层的情况

    • Embedding 的输出在 TP-only 下仍然存储完整序列

    • TP+SP 下序列被切分,减少显存使用

    Region
    Vanilla TP
    TP + SP

    Embedding layer (row-linear,按词表切分)

    h*: 完整 (weight_out 全 + all-reduce) s*: 完整

    h*: 完整 (weight_out 全 + reduce-scatter) s*: reduce-scatter 切分

  • 使用 TP+SP 是否会比普通 TP 带来更多的通信开销?

    • 在使用普通 TP 的前向传播中,每个 Transformer block 需要两次 all-reduce 操作;而在使用 SP 时,每个 Transformer block 需要 两次 all-gather 和两次 reduce-scatter 操作

    • 因此,从表面上看,SP 的通信操作次数是 TP 的两倍

    • 但是,由 一次 all-reduce 操作可以分解为一次 all-gather 加一次 reduce-scatter,它们在通信成本上实际上是等价的

    • 同样的推理也适用于反向传播,因为在反向传播中只是使用了每种操作的共轭形式(即 no-op ↔ all-reduce,以及 all-gather ↔ reduce-scatter)

  • 通信与计算重叠问题

    tp_sp_overlap.svg - 在每一层中共有四次通信操作(Attention 两次,MLP 两次) - 和 vanilla TP 一样,TP+SP 也很难与计算进行有效重叠,这使得整体吞吐量高度依赖于通信带宽 - 因此和 vanilla TP 一样,TP+SP 通常只在单个节点内部使用 - TP 的通信在计算路径上,跨节点太贵 - 因此一般将 TP 的并行度限制在每个节点的 GPU 数量以内,例如 TP ≤ 8 - TP≤8:节点内通信(NVLink) - 带宽极高(600GB/s 量级),延迟极低 - all-reduce / all-gather / reduce-scatter 相对“便宜” - TP / TP+SP 的通信还能勉强忍受 - TP>8:跨节点通信(EFA) - 带宽骤降(~100–400Gbps),延迟显著上升 - TP 中的通信是在前向关键路径上的,无法像 ZeRO 那样大规模 overlap - 结果:每一层的通信都直接拉长 forward / backward 的 critical path,于是 per-GPU throughput 急剧下降

  • LayerNorm 的 all-reduce 操作

    • 在 SP 区域,每个 TP rank 只看到 不同的 sequence 子区间

    • LayerNorm 的参数($\gamma, \beta$)是对整个 hidden 维度共享,对 sequence 不共享

    • 因此 rank 0:用 $x[:, 0:s/2, :]$ 算出来一组梯度,rank 1:用 $x[:, s/2:s, :]$​ 算出来另一组梯度

    • LayerNorm 的参数是逻辑上全模型共享,物理上每个 TP rank 各自一份

    • 如果不做 all-reduce 操作,那么 $\nabla \gamma^{(0)} \neq \nabla \gamma^{(1)}$,参数会立刻发散

    • 所以必须进行全量规约,即$\nabla \gamma \leftarrow \text{AllReduce}(\nabla \gamma^{(i)})$

    • 由于LayerNorm 参数量极少:只有 $2h$($\gamma,\beta$),相比 QKV / MLP 的 $O(h^2)$,可以忽略不计

  • TP+SP 仍然存在的问题

    • 即便用了 SP,在 TP 区域(attention / MLP)仍然需要 $(b, s, \frac{h}{tp})$ 的激活值,因为 $s$ 是完整的

    • 当 $s = 32k / 64k / 128k$ 时,attention 激活(尤其是 QK / softmax)仍然爆炸

    • 这也是 TP+SP 不能解决“超长上下文”的本质问题,为此引出上下文并行(Context Parallelism)解决上下文太长的问题

    • 同时,由于TP 的通信在 forward / backward 关键路径,当模型太大时,TP 不得不跨节点,延迟和带宽无法忍受,吞吐量会断崖式下降,为此引出流水线并行(Pipeline Parallelism)解决模型太大的问题

上下文并行(Context Parallelism)

基本思想

  • TP+SP 无法应对超长上下文

    • 如前文所述,在 TP+SP 中,当 $s$ 变得非常大时,激活在 TP region 仍然线性随 $s$ 增长

    • 即使使用 activation recomputation,也必须在 layer 边界保存一些激活

    • 这些激活的内存下界是 $O(s)$​,无法再缩减

  • 上下文并行的本质是:把 SP 从“只在非 TP 模块用”,升级为“在整个模型中都用”

    • 即把输入序列沿 sequence 维切分到多个 GPU,每个 GPU 只负责一段上下文(context slice)

    • 这一点和 SP 类似,但作用范围扩大到了 Attention / MLP 本身

    • 这样,在 TP region 里:每个 GPU 处理的激活大小从 $(b, s, h_{tp})$ 变成 $(b, s / cp, h_{tp})$,激活显存随 $s$ 的增长被 $cp$​ 倍稀释

  • 对于 MLP、LayerNorm、Dropout、逐 token 的线性变换等模块,切分序列是免费的

    • 以上模块都有一个共同点:每个 token 是独立处理的

    • 因此当把 sequence 切分后,不需要权重通信(权重不切),不需要跨 GPU 交换激活,backward 后只需要像 DP 一样对梯度做一次 all-reduce,其通信成本非常低

  • 而对于 Attention 模块

    • 在注意力计算中,每个 token 都要访问所有 token 的 key/value(在 casual attention 中,至少需要关注之前的每个 token 的 key/value)

    • 当 sequence 被切分后,每个 GPU 只持有一部分 $Q$,但需要全局的 $K, V$

    • 如果直接 all-gather 全部 $K, V$,通信量是 $O(s \cdot h)$,显存和带宽都会直接爆炸

    • 因此 CP 的可行性完全取决于有没有一种方法能高效地交换 $K, V$​,而不是一次性 gather

  • 上下文并行与 FlashAttention

    • 二者都使用 online softmax,都避免保存完整 attention matrix

    • 区别在于:FlashAttention 是单 GPU 内的 kernel 级优化,上下文并行是跨 GPU 的并行策略

    • CP 关注的是“序列如何分布到多卡”,不是单卡算得多快

  • Ring Attention 的核心思想

    • K/V 不一次性 gather,而是在 GPU 之间按环形顺序流动

    • 每个 GPU 一边接收 K/V,一边立刻算 attention 的一部分

    • softmax 用 online 方式累计

朴素 Ring Attention

ring-attention.gif

  • 假设 4 张 GPU,序列长度 $s = 4$,每张 GPU 上 1 个 token,GPU $i$ 初始持有 $(Q_i, K_i, V_i)$

  • 每一轮(总共 4 轮),每张 GPU 做三件事:

    • 非阻塞地把当前的 $K, V$ 发送给下一个 GPU

    • 用当前持有的 $K, V$ 计算局部 attention:

      partiali=Softmax(QiKd)V\text{partial}_i = \text{Softmax}\left(\frac{Q_i K^\top}{\sqrt{d}}\right)V
    • 等待从上一个 GPU 接收到新的 $K, V$,进入下一轮

  • 这样,每张 GPU 每一轮都会“看到”新的 token 的 $K, V$

  • 4 轮之后,每个 $Q_i$ 都和所有 token 交互过,不需要任何时刻保存完整 $K, V$

  • 这就是 “Ring” 的来源:$K/V$​ 在 GPU 之间绕一圈

  • 在casual attention 中,Ring Attention 存在严重的负载不均衡

    cp_attnmask.svg - 因果注意力的 mask 是下三角矩阵,token $i$ 只能看 $[1, i]$ - 而 softmax 是按行算的,因此一行只要“所需的 K/V 都到了”,这一行就可以立刻算完,不需要等完整一圈,于是不同 GPU 的“等待条件”完全不同 - 如图中所示:GPU 1 很快空闲,GPU 2 少量工作,GPU 3、4 成为瓶颈,整体吞吐被最慢的 GPU 决定

Zig-Zag Ring Attention

cp_zigzagmask.svg

  • 核心思想

    • Zig-Zag 的目标是在不破坏 Ring Attention 的通信结构,不引入全量 K/V 存储,不牺牲 online softmax 的前提下,让每张 GPU 的计算量接近一致

    • causal attention 的计算量取决于 token 在序列中的“相对位置”,越靠前的 token,可见的 key 越少,而越靠后的 token,可见的 key 越多

    • Zig-Zag 的做法是不按 token 的自然顺序分配到 GPU,而是“前后交错”地分配

    • 每张 GPU 同时拿到一部分早期 token + 一部分晚期 token,平均下来,每张 GPU 需要算的 attention 行数和每行的长度都接近

    • 这一步不改变 attention 的数学定义,只改变 token 到 GPU 的映射

  • 每个 GPU 都需要从其他 GPU 那里获取信息

    • 即使做了 Zig-Zag,每个 token 仍然可能需要 attend 到序列中任意更早的位置,而这些位置被分散在所有 GPU 上

    • 所以不可避免地,每张 GPU 最终仍然需要来自所有其他 GPU 的 K/V 信息,区别只是“什么时候拿、拿多少、怎么拿”

  • Zig-Zag Ring Attention 的通信与计算重叠

    • all-gather

      cp_overlap_allgather.svg - 所有 GPU 同时执行一次 all-gather,每张 GPU 一次性拿到完整的 K/V - 需要更多的临时内存,因为每块 GPU 需要一次性存储所有的 K/V 对 - 通信一步完成,但内存开销更大

    • all-to-all (Ring)

      cp_overlap_all2all.svg - K/V 以 chunk 为单位在 GPU 间按环流动 - 每张 GPU 每次只多拿一小块 K/V,拿到就立刻参与 attention 计算 - 通信与计算分散且重叠,但通信次数更多,每一轮有固定的基础通信开销,实现复杂度更高,因此更适合超长序列

流水线并行(Pipeline Parallelism)

基本思想

  • SP 和 CP 存在的问题

    • SP 和 CP 在处理长序列时很有帮助,但如果内存问题的根本原因并不是序列长度,而是模型本身的规模,这些方法的帮助就非常有限

    • 对于大模型(70B+参数),仅仅是权重本身的大小,就已经可能超出单个节点上 4–8 张 GPU 所能承载的上限

    • 为了解决这个问题,需要引入另一种并行维度:流水线并行(Pipeline Parallelism,PP)

  • 基本思想

    • 流水线并行是一种概念上非常简单但威力强大的技术——将模型的层切分并分布到多张 GPU 上

    • 例如,如果有 8 张 GPU,可以将第 1–4 层放在 GPU 1 上,第 5–8 层放在 GPU 2 上,依此类推,这样一来,每张 GPU 只需要存储和处理模型的一部分层,大幅降低了单卡的参数显存需求

    • 虽然模型参数被很好地切分到了不同 GPU 上,但每张 GPU 上的激活值显存占用却保持不变——这意味着,流水线并行并不能节省激活值显存,这是因为在开始第一次反向传播之前,每张 GPU 都需要先执行 $PP$ 次前向传播。虽然每张 GPU 只负责 $1/PP$ 的层数,但在第一次反向传播开始之前,它需要处理 $PP$ 个 micro-batch,因此最终需要存储的激活值数量为

    PP×(activs/PP)activsPP×(activs/PP)≈activs
    • 也就是说,激活值显存需求与不使用流水线并行时大致相同

  • 这也引入了一种新的通信模式:不再像 ZeRO-3 的数据并行那样以通信的形式传递参数,而是以“流水线”的形式在 GPU 之间顺序传递激活张量

朴素实现

image.png

  • 基本思想

    • 把模型的层简单地分布到多个设备上——例如,第一个 GPU 负责前几层,第二个 GPU 负责模型的中间部分,依此类推

    • 此时,模型的前向传播就变成了一个 batch 的数据沿着模型深度被顺序地传递,依次使用每一个 GPU

    • 这种方式对互连带宽的需求非常低,因为只在模型被划分的中间点处以通信的形式传输中等大小的激活值。这和 TP 形成了鲜明对比,在 TP 中,通信会在每一层内部发生多次

  • 流水线的效率

    • 上图中的灰色部分通常被称为 “Bubble” 气泡,可以通过计算气泡导致的时间损失来量化流水线的效率

    • 设 $t_f$ 和 $t_b$ 分别表示单个 micro-batch 在单个流水线 stage 上的前向和反向计算时间(一个常见的简化假设是 $t_b \approx 2 \times t_f$)

    • 如果能够做到完美并行,那么理想的总时间为

      tid=tf+tbt_{id} = t_f + t_b
    • 然而,在朴素实现中,流水线 bubble 的存在会引入额外的时间开销

      tpb=(p1)×(tf+tb)t_{pb} = (p - 1) \times (t_f + t_b)
    • 其中 $p$ 是流水线并行的度数,也就是 GPU 的数量,这部分时间表示某些 GPU 在等待其他 GPU 计算完成

    • 将额外的 bubble 时间与理想时间的比值写成

      rbubble=(p1)×(tf+tb)tf+tb=p1r_{bubble} = \frac{(p - 1) \times (t_f + t_b)}{t_f + t_b} = p - 1
  • 朴素实现是顺序、依次的,而流水线并行的核心挑战,正是如何有效绕开这种顺序性,让所有 GPU 始终保持忙碌,而不会出现某一张卡在计算、其他卡在等待的情况

All forward, All backward(FAFB)

pp_afab2.svg

  • 通过增加 micro-batch 的数量 $m$,可以将 bubble 缩小 $m$​​​ 倍,从而缓解流水线带来的效率损失

    rbubble=(p1)×(tf+tb)m×(tf+tb)=p1mr_{bubble} = \frac{(p - 1) \times (t_f + t_b)}{m\times(t_f + t_b)} = \frac{p-1}{m}
  • 这种调度方式被称为 all forward, all backward(AFAB):先完成所有前向传播,再完成所有反向传播

  • 它的优点在于,前向和反向仍然各自保持顺序结构,整体训练代码的组织方式几乎不需要很大改动,因此实现起来非常简单

  • 然而,需要处理的还有激活值带来的显存开销,在 AFAB 这种实现中,必须一直把所有激活值保存在显存中,直到进入反向传播阶段,这会很快导致显存爆炸

One forward, one backward(1F1B)/ Llama 3.1 流水线

image.png

  • 1F1B 的核心思想是:在中间的稳定阶段中,计算过程是在一次前向传播和一次反向传播之间交替进行的,即尽可能早地开始反向传播

  • 1F1B 方法中 bubble 的大小和 FAFB 保持一致,因此从纯计算效率角度看,训练效率并没有得到显著提升

  • 但和 AFAB 不同的是,只需要为 $p$ 个 micro-batch 保存激活值($p$ 是流水线并行的度数),而不是为 $m$ 个 micro-batch 保存激活值($m$​ 为 micro-batch 总数),这显著缓解了 AFAB 中出现的激活显存爆炸问题

  • 正因为激活显存占用降低,就可以使用更多的 micro-batch,而这反过来又能够减小 bubble 的影响

  • 这种方案的复杂性在于前向和反向传播不再是干净、顺序的两个阶段,而是跨设备并行、交错执行的,这意味着必须在每个设备上独立调度从前向到反向的切换,而不能再像传统训练流程那样,在一个简单、统一的中央训练循环中完成控制

  • 1F1B 中 bubble 对性能的影响

    Throughput scaling of pipeline parallelism with varying micro-batch sizes - 在左图中,当 micro-batch 数量小于或等于流水线并行度减一($m = p - 1$)时,可以清楚地看到 pipeline bubble 的破坏性影响:吞吐率很低,而且随着 PP 度数增加进一步下降 - 右图展示了当 micro-batch 数量远大于流水线并行度(例如 $m = 32 \gg p - 1$​)时,低 PP 度数下的性能显著改善,但在 PP 度数非常大时,仍然会受到限制 - 在实际训练中,并不能无限制地增加 micro-batch 的数量来维持 $m \gg p - 1$,因为最终还要受到目标全局 batch size 的约束 - 当 PP 度数不断增加而 micro-batch 数达到上限时,bubble 的大小最终仍会按照线性关系增长:

    rbubble=p1mr_{bubble}= \frac{p-1}{m}
    • 一个有趣的现象是:在 micro-batch 数较小的情况下,从一个节点($p = 8$)扩展到两个节点($p = 16$)时,性能只下降了 14%

    • 相比之下,张量并行在类似的跨节点场景中出现性能下降大约 43%,可见流水线并行在跨节点网络带宽较低的场景中仍然保持较好的扩展性

Interleaving stages(交错流水线)

pp_1f1b_interleaved.svg

  • 基本思想

    • 1F1B 调度在显存占用上已经有明显改进,但并未有效降低 idle bubble 的大小,可以以引入额外的通信操作为代价解决这个问题

    • 前述方法都是沿着模型深度维度对模型进行“朴素切分”,例如把第 1–4 层放在第一张 GPU 上,把第 5–8 层放在第二张 GPU 上

    • 但可以用别的切分方式切分模型,比如把奇数层(1、3、5、7)放在第一张 GPU 上,把偶数层(2、4、6、8)放在第二张 GPU 上

    • 这可以理解为一种“循环流水线”,一个 micro-batch 在前向传播过程中,会在多张 GPU 之间来回流动,而不是像之前那样只线性通过一次

  • 这种方案需要额外的通信,因为同一个 micro-batch 在一次前向或反向计算中会多次经过同一张 GPU,而在之前的方案中只需要经过一次

  • 不过,每一次前向和反向计算的时间都会被缩短为原来的 $\frac{1}{v}$,其中 $v$ 是每张 GPU 上的 stage 数量(模型 chunk 数)

  • 对应的 bubble 时间和 bubble 比例可以写为:

    tpb=(p1)(tf+tb)vrbubble=(p1)vmt_{pb} = \frac{(p - 1) \cdot (t_f + t_b)}{v}\\ r_{bubble} = \frac{(p - 1)}{v \cdot m}
  • 由此,可以同时通过增加 micro-batch 数 $m$,以及增加交错阶段数 $v$,来进一步减小 bubble;同时从量化角度看,通信量也会随 $v$ 成比例增加,因此这本质上是一个 trade-off

    • $m = 1, v = 1$ 对应的是最朴素的流水线并行

    • $v = 1$ 对应 AFAB 或 1F1B

    • $v \neq 1$ 对应交错流水线

  • 交错流水线的调度策略

    pp_llama3.1_schedule.png - 在交错流水线方案中,调度本身也会变得更加复杂,因为在任意一张 GPU、任意一个时刻,都需要决定 - 优先让较早的 micro-batch 通过较后的层,从而尽快闭合前向和反向回路(depth-first,目标是尽快把 batch 跑完整个模型) - 还是优先让较晚的 micro-batch 通过较早的层,从而尽可能填满整个流水线(breadth-first,目标是最大化流水线利用率) - [Breadth-First Pipeline Parallelism](https://arxiv.org/abs/2211.05953)

Zero bubble / DualPipe

Zero bubble

  • 以一个线性层(矩阵乘法)为例,其反向传播实际上包含两个相互独立的计算:

    • 对输入的反向传播,记为 $B$(backward w.r.t. activations)

    • 对权重的反向传播,记为 $W$(backward w.r.t. weights)

  • 它们在依赖关系上并不对称:

    image.png - $B$ 的输出(即对输入的梯度)是下一层反向传播所必需的 - $W$​ 的输出(权重梯度)并不影响后续层的反向传播,只需要在 optimizer.step 之前完成即可 - $B$ 在计算图中有“向前”的依赖,而 $W$​ 没有

  • 那么,只要同一 stage 的 $B$ 已经完成,对应的 $W$​ 就可以被延后,在任意空闲时间执行

  • 即用 $W$ 去“填 bubble”,在传统的 1F1B 或交错流水线中,bubble 产生的原因是某些 GPU 在等前面的 stage 或后面的 stage 完成计算,时间被浪费掉了

  • Zero Bubble 的关键思想是:

    image.png - 把 backward 从一个粗粒度操作拆成 $B$ 和 $W$ - 保证 $B$ 按依赖顺序尽快完成 - 把 $W$ 插入到原本会空闲的时间片中执行,用来填补 bubble

DualPipe

image.png

  • DeepSeek 的 DualPipe 可以看作是 Zero Bubble 思路的工程级扩展

    • 在 Zero Bubble 中,只有一条从前到后的流水线,bubble 通过 $W$ 的重排来填充

    • 在 DualPipe 中,同时存在两条沿着 pipeline 维度传播的计算流,一条从前往后,一条从后往前,两条流在时间上进一步交错

    • 这样做的效果是,即使某一方向暂时没有可执行的 $B$,另一方向的 $B$ 或 $W$​ 也可能正好可以执行,GPU 的空闲时间被进一步压缩

    • DualPipe 比前面的所有方案都复杂,但本质仍然是同一个逻辑:把依赖强的操作($B$)尽量前推,把依赖弱的操作($W$)用于填充气泡

  • 这类“近零 bubble”方案在概念上很优美,但实现难度极高,原因主要有三点:

    • backward 被拆成非常细粒度的操作,调度空间巨大

    • 不同算子(attention、MLP、通信)的执行时间并不完全一致

    • 要真正最小化 bubble,本质上是一个调度优化问题

  • 在 Zero Bubble 论文中,作者提到需要:

    • 精确测量各类细粒度操作的耗时

    • 构建调度约束

    • 通过整数线性规划(ILP)或启发式算法求解近似最优调度

专家并行(Expert Parallelism)

什么是 MoE

ep_moe.png

  • MoE 的基本思想

    • 近年来,MoE 范式随着 GPT-4、Mixtral、DeepSeek-V3 / R1 等模型的出现而受到广泛关注

    • 其核心思想是:在每一层中,不再只使用一个前馈网络,而是使用多个并行的前馈模块(即多个“专家”),并通过路由器(router)将不同的 token 分发到不同的专家中进行处理

    • 这样,不同 token 可以被“差异化”地计算

  • 什么是 Expert Parallelism

    • MoE 层的结构天然适合在“专家维度”上做并行,这就是所谓的 Expert Parallelism

    • 这是因为,每个专家对应一个独立的前馈网络,各个专家之间在计算上互不依赖

    • 因此,实现 EP 时,只需要把不同专家的前馈网络放到不同的 GPU / worker 上,将 token 的隐藏状态路由到对应的专家

    • 与张量并行相比,EP 更加轻量,不需要拆分矩阵乘法,只涉及 token 隐状态的路由与通信

专家并行的难点

  • 专家并行的整体流程

    • 本地路由

      • 每个设备只对自己负责的一部分 token 计算路由决策,即判断每个 token 应该被分配给哪些专家

      • 门控网络通常在所有设备上是完整复制的,以保证路由决策的一致性和局部可计算性

    • All-to-All 分发

      • 路由决策确定后,若某些 token 被分配到位于其他设备上的专家

      • 这些 token 的隐藏状态需要通过 All-to-All 集体通信发送到对应的设备

    • 并行专家计算

      • 每个设备接收到属于本地专家的 token 后

      • 独立执行专家网络的前向和反向计算,生成对应的输出

    • All-to-All 聚合

      • 专家计算完成后,token 的输出结果需要再次通过 All-to-All 通信

      • 将结果发送回 token 最初所在的设备,以便继续后续模型层的计算

  • All-to-All 通信是专家并行的核心

    • 根据门控网络的输出,每个设备都可能需要向所有其他设备发送一部分 token,同时也需要从其他设备接收属于本地专家的 token

    • 这种通信模式不同于数据并行中的 All-Reduce,All-Reduce 是对称、规整的聚合通信,而 All-to-All 则是高度不规则、数据依赖路由结果的通信模式

    • 由于专家选择具有稀疏性和不均匀性

      • 通信通常伴随着较高的带宽需求和复杂的同步开销

      • GPU 本身并非为这种高度分支化的计算模式而设计

      • 频繁的跨设备数据交换往往使网络带宽成为主要性能瓶颈

  • Transformer 中包含 MoE 层的前向传播典型顺序

    • 输入 token 经过自注意力机制

    • 输入 token 经过第一个前馈网络(FFN)的输入投影部分

    • 门控网络为每个 token 计算专家分配结果

    • All-to-All 通信:将 token 分发到其对应专家所在的设备

    • 各专家网络对分配到的 token 执行计算

    • All-to-All 通信:将专家输出的 token 发送回其原始设备

    • 根据门控分数对来自不同专家的输出进行加权与聚合

    • 输出 token 经过第二个 FFN 层(输出投影)

    • 输出 token 进入下一层计算(如 LayerNorm、残差连接等)

  • 通信与计算的重叠

    • 在前向传播中,两次 All-to-All 通信(token 分发与结果回传)是最主要的重叠对象

    • 在反向传播中,梯度同样需要按照路由路径返回专家,这又引入了额外的 All-to-All 通信阶段,同样可以进行重叠

  • 实现通信与计算重叠的方法

    • 异步通信原语

      • 避免使用阻塞式通信调用,因为它们会在通信完成前停止后续执行

      • 优先使用非阻塞通信接口,例如使用非阻塞的 All-to-All 变体或更底层的点对点原语,如 isend / irecv,并配合显式同步

    • 典型的执行模式

      • 启动非阻塞通信操作(如 handle = isend / irecv)

      • 在通信进行的同时,执行不依赖通信结果的计算,例如:

        • 流水线并行中后续层的计算

        • 网络中并行分支或独立模块的计算

        • 反向传播时,MoE 层通信之前的梯度计算

        • 专家计算中可先处理本地已就绪的数据

      • 在需要使用通信结果之前,通过 handle.wait() 显式等待通信完成

    • 借助 CUDA 流实现更细粒度的并发

      • CUDA 流提供了在 GPU 上调度并发操作的机制

      • 不同流上的计算核与通信操作可能并行执行

      • 通信库通常使用独立的 CUDA 流完成数据搬运和网络传输

      • 通过精细管理流之间的依赖关系,可以让计算核与通信同时进行,从而最大化硬件利用率

      • 这种方式通常需要框架层的深度支持或在自定义算子与通信逻辑中直接操作 CUDA 流

工程实践

  • EP 为什么要和其他并行方式一起用

    ep_schema.png - 在实际训练中,EP 通常不会单独使用,而是与数据并行等方式结合 - 原因在于 EP 只作用于 MoE 层,对非 MoE 层(如注意力层、普通前馈层)不起作用,EP 也不会像上下文并行那样切分序列维度上的 token - 如果只使用 EP,那么所有 GPU 仍然需要对非 MoE 模块做完全重复的计算,效率并不高 - 因此,常见做法是用 EP 在专家维度上切分 MoE 层,用 DP 在 batch 维度上切分输入数据 - 这样可以同时避免专家冗余和数据冗余

  • 工程上的关键技巧

    • 为了让 EP 在实践中高效运行,模型设计上通常需要配合一些约束

    • 以 DeepSeek-V3 为例,路由器被限制为每个 token 最多只会被发送到 $M$ 个节点,其中 $M = 4$,这样做可以尽量让一个 token 的计算集中在少量节点上,减少跨节点通信带来的开销

并行策略总结

总览

  • 迄今为止,一共有五种用于扩展模型训练的并行策略,分别对应模型或数据的不同维度:

    • 数据并行(Data Parallelism, DP):沿 batch 维度切分

    • 张量并行(Tensor Parallelism, TP):沿隐藏维度切分

    • 序列 / 上下文并行(Sequence / Context Parallelism, SP / CP):沿序列长度维度切分

    • 流水线并行(Pipeline Parallelism, PP):沿模型深度(层)切分

    • 专家并行(Expert Parallelism, EP):沿专家(MoE experts)维度切分

  • 此外,还有三种可以与数据并行结合、用于降低显存占用的 ZeRO 策略:

    • ZeRO-1:在 DP 副本之间切分优化器状态

    • ZeRO-2:在 DP 副本之间切分优化器状态和梯度

    • ZeRO-3:在 DP 副本之间切分优化器状态、梯度和参数

    • 一个自然的问题是:这些并行方式和 ZeRO 策略之间,哪些可以高效组合,哪些组合起来收益有限甚至不划算?

流水线并行与 ZeRO-3 的对比

  • 相似点

    • 两者都将模型参数分布在多张 GPU 上

    • 都沿模型深度方向进行通信与计算

    • 在每个设备上执行的是“完整层级”的计算,而不是子层计算(不同于 TP 或 EP 那样对单层进行切分)

    • 这里说的“一层”只是为了简化描述,实际的切分单位可以是多层、单层,甚至层的一部分,取决于具体实现

  • 可以从以下几个方面理解两者的不同:

    ZeRO-3
    Pipeline Parallelism (PP)

    存储单位

    每个计算单元只存储层的一部分参数

    每个计算单元存储完整的一层(或若干层)

    通信内容

    权重参数(weights)

    激活值(activations)

    调度特性

    模型无关(model-agnostic)

    模型无关(model-agnostic)

    实现难点

    复杂:需要处理模型切分、参数预取与释放

    复杂:需要设计高效的流水线调度以减少空泡

    扩展偏好

    偏好大 micro-batch size(mbs)和长序列长度(seq_len),以隐藏通信开销

    偏好大梯度累积步数(grad_acc)以隐藏流水线空泡

    每个设备计算粒度

    子层级或参数分片

    整层或多层

    目标

    降低单节点显存占用,通过切分参数、梯度和优化器状态

    提高训练吞吐量,通过流水线并行覆盖层级计算

    是否常组合

    可与 PP 组合,但需显著增加全局 batch size

    可与 ZeRO-1/2 组合;与 ZeRO-3 组合较少,成本高

  • 是否需要组合 PP 和 ZeRO-3

    • 理论上,PP 和 ZeRO-3 是可以组合的,但实践中并不常见,原因在于:

      • 组合后需要显著增大全局 batch size 才能摊薄通信成本

      • 这会引入在 batch size、模型规模、网络带宽和训练效率之间的权衡

    • 如果确实要组合使用,通常需要让 ZeRO-3 在多个 PP micro-batch 期间尽量保持权重常驻显存,以减少不必要的重复通信

    • 相比之下,ZeRO-1 和 ZeRO-2 只涉及优化器状态和梯度,它们与 PP 的组合非常自然,几乎没有新的工程复杂度

    • 例如,DeepSeek-V3 的训练就采用了 PP + ZeRO-1

TP / SP 与 PP、ZeRO 的互补关系

TP & SP diagram

  • 张量并行(TP)通常与序列并行(SP)一起使用,并且可以自然地与 PP 或 ZeRO-3 组合

  • 原因在于 TP 利用了矩阵乘法的可分配性,权重和激活可以被切分后分别计算,再通过通信聚合结果

  • 实践中,TP 单独扩展存在两个主要限制:

    • 通信处于计算关键路径上,规模扩大到一定程度后,通信会主导整体时间

    • TP 不是模型无关的,需要非常仔细地处理激活切分方式

      • 有些地方沿隐藏维度切分(TP 区域)

      • 有些地方沿序列维度切分(SP 区域)

    • 这使得 TP 的实现对模型结构高度敏感,工程复杂度较高

  • 因此,实际中的组合策略是:

    • TP 主要用于节点内(intra-node)的高速通信

    • PP 或 ZeRO-3 用于跨节点(inter-node)的并行

      • PP:通信带宽需求相对较低

      • ZeRO-3:通信更容易与计算重叠

    • 设计重点在于:

      • 合理划分 GPU 分组

      • 让 TP 通信尽量限制在单节点内

      • 同时避免 TP 的扩展瓶颈

CP 和 EP

CP diagram

  • CP 的目标是解决超长序列训练的问题,通过在序列维度上切分激活:

    • MLP、LayerNorm 等模块可以直接在分片序列上独立计算

    • 注意力层需要通信,因为每个 token 都需要访问全序列的 key / value

    • 这一问题通过 Ring Attention 等通信模式解决,使通信与计算尽量重叠

  • CP 在以下场景尤其关键:

    • 序列长度达到 $128k+$ token

    • 即使使用全量激活重计算,单卡显存仍无法容纳注意力中间状态

EP diagram

  • EP 专门用于 MoE 模型训练,通过:

    • 将不同专家分布在不同 GPU 上

    • 在计算时将 token 动态路由到对应专家

    • 再将结果聚合回原计算流

    • 其主要通信开销来自 all-to-all 操作(token 分发与回收)

    • 但 EP 带来的好处是每个 token 只经过全部参数中的一小部分,模型容量可以大幅提升,而计算成本增长相对可控

    • 当专家数量非常大时(例如 DeepSeek-V3 使用了 256 个专家),EP 就变得尤为重要

  • 由于在输入处理方式上,EP 与 DP 都涉及 token 在 GPU 之间的分发,一些实现中会将 EP 视为 DP 的一种变体:

    • DP:所有 GPU 运行相同模型副本

    • EP:GPU 上是不同专家,通过路由而非复制来处理输入

    • 两者的关键差异在于是否存在“专家专用结构”

作用范围

  • 张量并行(TP)+ 序列并行(SP):影响整个模型的计算,通过切分权重和激活来实现分布式计算

  • 上下文并行(CP):主要作用于 注意力层,因为跨序列通信只在注意力计算中需要,其他模块(如 MLP、LayerNorm)可以独立处理分片序列

  • 专家并行(EP):主要作用于 MoE 层,注意力层和其他模块保持不变

  • 流水线并行(PP)和 ZeRO:不局限于特定子模块,但 PP 中需要平衡模块和层的分配,第一层和最后一层通常会被特殊处理(例如额外的 embedding 层)

并行方式
作用对象
通信类型
实现特性
偏好 / 限制

张量 + 序列并行

整个模型的权重和激活

矩阵乘法操作(列/行线性变换)

需要针对具体模型实现

偏好高速节点内通信(高带宽)

上下文并行

注意力层的激活

注意力键/值的跨序列通信

除注意力外为模型无关

偏好超长序列训练

专家并行

MoE 层的权重和激活

token 路由到专家的通信

除 MoE 层外为模型无关

需要 MoE 层

总结

  • 单一 Transformer Block(MoE 变体)的激活和模块图示

image.png

  • 每种策略的内存节省情况

5Dparallelism_8Bmemoryusage.svg

  • 各种并行方法的对比

    方法
    显存节省作用对象
    并行 / 切分维度
    缺点

    数据并行(DP)

    激活值(通过减小本地 batch size)

    Batch

    受最大 batch size 限制

    流水线并行(PP)

    模型参数

    模型层

    空泡(idle bubble)问题,调度复杂

    张量 + 序列并行(TP+SP)

    模型参数和激活

    隐藏维度 / 序列长度

    需要高速通信带宽

    上下文并行(CP)

    激活值

    序列长度

    注意力模块增加通信开销

    专家并行(EP)

    MoE 层参数

    专家维度

    需要 MoE 层,增加 token 路由通信开销

    ZeRO-1

    优化器状态

    在 DP 副本间切分

    参数通信开销

    ZeRO-2

    优化器状态和梯度

    在 DP 副本间切分

    参数通信开销

    ZeRO-3

    优化器状态、梯度和模型参数

    在 DP 副本间切分

    参数通信开销

如何找到最佳的训练配置

  • 总结来说,寻找最佳训练配置的关键思路是:

    • 先保证单个模型实例在 GPU 上能跑

    • 调整全局 batch size 达到训练目标

    • 优化吞吐量,合理组合 TP / DP / PP / CP / EP / ZeRO 策略

步骤 1:让训练步能在显存中跑得下

  • 首先要确保单个模型实例能放入 GPU 显存情况,大致分两类:

  • GPU 资源充足(GPU-rich)

    • 对于 <10B 参数模型:单一并行策略即可,例如张量并行(TP)或者 ZeRO-3 / 数据并行(DP)+ 全量重算,使用 8 张 GPU

    • 对于 10B–100B 参数模型、需要 >8 张 GPU,可选方案包括:

      • TP=8 + 流水线并行(PP)

      • TP=8 + 数据并行(ZeRO-3)

      • 仅 ZeRO-3(纯数据并行)

    • 512+ GPU 规模:

      • 纯 DP / ZeRO-3 开始因通信开销而效率下降

      • 更优方案:DP + TP 或 PP 组合

    • 1024+ GPU 规模:

      • 推荐组合:TP=8 + DP(ZeRO-2)+ PP

    • 注意:这里的重点是单模型实例能跑起来,使用 ZeRO-3 帮助节省模型参数显存

    • 特殊情况:

      • 超长序列 → 可能需要跨节点使用 CP

      • Mixture of Experts(MoE)架构 → 可能需要跨节点使用 EP

  • GPU 资源有限(GPU-poor)

    • 可开启全激活重算(trade compute for memory)

    • 可增加梯度累积步数,在显存有限情况下处理更大 batch

步骤 2:达到目标全局 batch size

  • 在步骤 1 后,得到了一个可训练的微批量(micro-batch)配置,但可能离目标全局 batch size 还差距很大

  • 增大全局 batch size

    • 增加数据并行(DP)或梯度累积步数

    • 超长序列时,可利用 CP 增大 batch size

  • 减小全局 batch size

    • 减少数据并行,转向其他并行策略

    • 超长序列时,可减少 CP 使用

步骤 3:优化训练吞吐量(Throughput)

  • 现在模型已经能跑在目标模型大小和 batch size 上,但希望训练尽可能快,让 GPU 高效利用,可尝试的优化方法:

    • 扩展张量并行(TP)

      • 利用高速节点内带宽

      • 尽量达到单节点 GPU 数量

      • 这样可以减少其他并行策略的依赖

    • 增加数据并行 + ZeRO-3

      • 保持目标 batch size

      • 如果 DP 通信成为瓶颈 → 转向流水线并行(PP)

    • 逐个尝试扩展不同并行策略

      • 找到最优组合

    • 调优微批量大小($mbs$)

      • 在全局 batch size、模型规模、计算量、通信开销之间找到最佳平衡

基准测试

image.png

  • 在 Nanotron 仓库中,可以找到多个脚本,用于运行之前讨论的所有实验,并对自己的模型和集群进行基准测试

  • 基准测试的设置

    • 序列长度:4,096

    • 全局 batch size:1M tokens

  • 热力图中展示了每个模型和集群规模下的最佳配置:

    • 每个组合显示:DP、TP、PP、梯度累积步数(GAS)、微批量大小(MBS)、ZeRO 优化阶段

    • 颜色亮度表示模型 FLOPs 利用率(MFU),越亮表示效率越高

  • 节点数增加 → 效率下降

    • 对小模型尤为明显,因为计算量与模型规模比低

    • 虽然可以通过增加 batch size 来补偿,但受全局 batch size 限制(1M)

  • 大模型面临显存压力

    • 模型规模增大 → 显存需求显著增加

    • 节点数少时:模型可能无法放入显存,或者运行效率低下,因为显存几乎满载

    • 例如,80B 参数模型在 4 节点上训练时就属于此类情况

  • 性能依赖实现质量

    • 初期:TP 优于 PP

    • 优化 PP 实现后:PP 性能更好

    • 目前优化 TP 的通信重叠后,预计 TP 会再次领先

基准测试中的经验教训

  • 基准测试的目标不仅是讨论理论和实现,更是提供实际测试数据。流程大致如下:

    • 针对每个模型和多个集群规模,尝试所有可能的分布式配置

    • 排除不可行配置后,仍需运行数千次实验

  • 实际执行中遇到的挑战:

    • PyTorch 进程有时无法正确清理

    • Slurm 作业管理器会强制终止作业,导致节点失败

    • 简单基准测试可能从几分钟变成几小时

    • 一些作业可能无限挂起

  • 为保证有限时间内完成实验,实验者做了额外工程工作:

    • 最小化集群重启时间,优化空闲时间

    • 分析 NCCL 调试日志

    • 理解显存使用模式和 CUDA 内存分配器行为

    • 改进多节点流水线并行性能

  • 这些经历说明:

    • 理论看似简单,但在实践中涉及许多细节

    • 实际复现理论结果需要高度关注硬件和软件环境

    • 分布式训练不仅是算法问题,也涉及工程能力

FlashAttention

  • 到目前为止的并行运算技术往往依赖一个关键假设:计算和通信可以在 GPU 上高效重叠,而不会影响计算吞吐量,但实际情况要复杂得多

  • 使用常见的通信原语(如 NCCL 的 send/recv)时,计算和通信资源之间存在潜在的竞争,因为通信内核通常会使用与计算相同的 GPU 流式多处理器(Streaming Multiprocessors)

  • 当通信与计算重叠时,这会导致吞吐量下降,要真正优化分布式训练,需要更深入地了解 GPU 的架构细节,可以详见 PyTorch 版块中的 CUDA 章节

  • 基本思想

    • FlashAttention 由 Tri Daoarrow-up-right 提出,目标是通过手写 CUDA 内核,让 attention 的计算更快、更省显存、更贴合 GPU 的内存与计算结构

    • FlashAttention 的核心思想是充分利用 GPU 中更快的内存,尽量避免使用最慢的那一层内存 —— 全局内存

  • 现代 GPU 的全局内存通常使用 HBM(High Bandwidth Memory)

    • HBM 带宽高,但仍然远慢于 GPU 内部的 SRAM、寄存器、共享内存、L1 cache

    • HBM 这个名字容易误导,High Bandwidth ≠ Low Latency,和片上 SRAM 相比,HBM 仍然是“慢内存”

  • Naive Attention 的内存灾难

    image.png - 标准 attention 的计算流程是: - 给定 $Q \in \mathbb{R}^{L \times d}$、$K \in \mathbb{R}^{L \times d}$、$V \in \mathbb{R}^{L \times d}$ - 计算注意力分数 $S = Q K^T \in \mathbb{R}^{L \times L}$ - softmax 归一化 $P = \text{softmax}(S)$ - 得到输出 $O = P V$ - 在 naive 实现中,$S$ 必须完整算出来,$P$ 也必须完整算出来,它们都被存进 HBM - 算完 $QK^T$ → 写回 HBM - 再从 HBM 读 $S$ → 算 softmax → 写回 HBM - 再从 HBM 读 $P$ → 算 $PV$ - 这在长上下文下是灾难性的,$S$ 和 $P$ 的大小是 $L^2$,它们是整个模型中最大的 activation 之一,带宽和显存同时爆炸

  • 在注意力计算的过程中,带宽瓶颈的本质是 attention 在算术上是大量矩阵乘法,计算密度并不低,但性能被严重限制在 HBM 带宽而非算力上

  • FlashAttention 的核心思想是直接在 GPU 的 SRAM(共享内存+寄存器)里分块计算、边算边归一化、边算边累加 $O$​

    • 传统做法:QK^T → 写 S → 读 S → softmax → 写 P → 读 P → PV

    • FlashAttention 的做法:Q_tile × K_tile^T → 在线 softmax(只保留 max / sum) → 直接更新 O_tile

    • $S$ 从头到尾从未完整存在过,只保留 softmax 所需的当前最大值、归一化因子(sum of exp)

    • 这样,不仅利用共享内存解决了 HBM 的带宽瓶颈,还释放了注意力矩阵的显存占用

image.png

  • softmax 有一个重要性质:

    softmax(xi)=eximjexjm\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
    • 只要维护当前块的最大值和累计的指数

    • 就可以一块一块处理 $S$,最终得到和一次性 softmax 完全一样的结果

  • 具体步骤

    • 分块:不一次性算完整个 $S$,而是把 $QK^T$ 分成小块,每一小块能放进共享内存

    • online softmax:利用 softmax 的重要性质,每个 block 维护当前最大值 $m$ 和累计归一化因子 $l$

    • 这样可以做到一边算 $QK^T$,一边算 softmax,一边直接乘 $V$,一边累加 $O$

    • 最终效果是:$S$ 从未完整存在过、$P$ 从未完整存在过,中间结果始终留在寄存器/共享内存,只在最后一次性把 $O$ 写回 HBM

  • FlashAttention 的变革

    • 在 FlashAttention 出现之前,大量工作集中在 linear attention、sub-quadratic attention 等各种近似 softmax

    • 它们的动机是 $O(L^2)$ 的算法复杂度无法忍受,但 FlashAttention 证明了问题不在算法复杂度,而是在内存访问模式

    • 于是精确的 attention + 正确的 kernel 设计,反而比近似方法更快、更省显存

  • 后续优化

    • FlashAttention-2:减少非 GEMM 操作,更精细地划分 warp、划分 thread block、更好地利用 GPU 的执行模型

    • FlashAttention-3:专门针对最新 Hopper(H100)架构,支持 FP8,深度使用 Tensor Core,进一步压榨算力与带宽

    • FlashAttention 对 attention pattern 有一定限制,不适用于所有稀疏或自定义 attention 形式,因此出现了 FlexAttentionarrow-up-right,在保持高性能的同时支持更灵活的 attention 结构

混合精度训练

基本思想

  • 在前面的章节在,已经看到过低精度训练的应用

    • 显存瓶颈并不只来自模型参数,激活和优化器状态往往更大

    • 低精度的核心价值就在于用更少的 bit 表示数值,换取显著的内存节省和更高的吞吐

    • 但这一定伴随着数值精度和稳定性的代价

  • 混合精度训练(mixed precision training)

    • 顾名思义,是指在训练过程中混合使用多种数值精度

    • 在 PyTorch 中,张量的默认数值精度是单精度浮点数,也称为 FP32 或 float32,这意味着每个数占用 32 位(bit),即 4 个字节(byte)

    • “混合”这个词非常重要,它并不是说全部都用低精度,而是有些张量用 FP32,有些用 FP16 / BF16 / FP8;有些计算在低精度做,但结果保存在高精度里,这是后面能保证稳定训练的关键

  • 浮点数的表示——用于表示一个浮点数的比特被分成三部分:

    sign-mantissa-exponent.svg - 浮点数并不是“定点表示”,而是用有限 bit 去近似表示一个连续实数空间 - 符号位(Sign):第一个 bit 决定这个数是正数还是负数,决定连续实数空间的反向 - 指数位(Exponent):控制数值的量级(范围),决定能表示多大或多小 - 尾数位(Mantissa,也叫有效数):决定数值的精度(有效数字),决定表示得有多细

    • 浮点数的原理可以通过科学计数法来直观理解,例如:$-5.734 \times 10^7$

    • 在这种表示中,先有符号(负号),接着是尾数(5.734),最后是指数($10^7$)

    • 计算机里的浮点数本质上就是科学计数法的二进制版本:

      x=(1)sign×mantissa×2exponentx = (-1)^{sign} \times mantissa \times 2^{exponent}
    • 区别只在于科学计数法是以 10 为底,浮点数是以 2 为底

  • PyTorch 提供的浮点数格式

    格式
    比特数
    符号位比特数
    指数位比特数
    尾数位比特数

    float32

    32

    1

    8

    23

    float16

    16

    1

    5

    10

    bfloat16

    16

    1

    8

    7

    float8 (e4m3)

    8

    1

    4

    3

    float8 (e5m2)

    8

    1

    5

    2

  • 范围和精度的取舍

    • 减少总 bit 数一定是有代价的,但可以选择代价付在哪里——要么减少尾数的 bit,要么减少指数的 bit

    • 浮点数的核心矛盾就是 exponent 决定“能表示多大/多小”(数值范围),而mantissa 决定“能表示多细”(数值精度),bit 总数有限,只能在范围和精度之间做取舍

    • 正因为这种取舍关系,FP8 并不是只有一种格式,而是有两种 float8 格式,根据指数位和尾数位命名

  • 不同数值的表示范围

    image.png - float32 可以覆盖约 80 个数量级 - float16 为了节省 bit,牺牲了大量数值范围 - bfloat16 则保留了与 float32 相同的数值范围, - float16 和 bfloat16 都是 16 bit,但 float16:指数少 → 容易 overflow / underflow,而 bfloat16:指数多 → 范围几乎和 float32 一样 - 对于 float8,e5m2 可以勉强维持 float16 的数值范围,而 e4m3 数值范围进一步缩小

  • 不同数值的分辨率(resolution)

    image.png - 在区间 $[1, 2]$​ 内取 10,000 个点,将这些点分别四舍五入到各个格式中最接近的可表示数 - 从图中可以看到,bfloat16 虽然维持了 float32 的数值范围,但代价是精度明显下降,可表示的数变得更稀疏 - 这正是 bfloat16 的设计哲学,宁愿数值表示得“粗一点”,也要避免数值溢出,因为对深度学习来说,梯度爆炸/消失比“多一点噪声”更致命 - 在 float8 的情况下,问题更加严重 - e4m3 在区间 $[1,2]$ 内只能表示 7 个数 - e5m2 甚至只能表示 3 个数

    • 这意味着在这个区间内大量不同的实数会被量化成同一个值,这对梯度、激活、权重更新都是极大的挑战

    • 衡量一个浮点格式分辨率的常用指标是 epsilon,即大于 1.00 的最小可表示数

      • epsilon 越小,表示在 1 附近能表示得越精细,数值误差越小

      • float32 的 epsilon 上界约为 $10^{-4}$,实际值大约是 $1.19 \times 10^{-7}$

      • float16 的 epsilon 约为 $10^{-3}$,bfloat16 的 epsilon 还要再大 10 倍

  • 混合精度训练的核心思想

    • 在某些计算中使用低精度格式,同时保持与全精度训练几乎一致的效果和稳定性

    • 关键不是“用不用低精度”,而是哪些地方可以用,哪些地方绝对不能用

    • 事实证明,无法彻底抛弃 float32,通常仍然需要在某些计算中使用全精度,如梯度累积、权重更新、loss scaling 的关键路径,否则训练会发散

FP16 / BF16

  • 最早的混合精度训练论文:Mixed Precision Trainingarrow-up-right

  • FP32 权重副本:

    • 使用 FP16 权重时,在训练过程中,一些权重可能会变得非常小,从而被直接舍入为 0

    • 即便权重本身不接近 0,如果梯度更新量非常小,在做加法时,由于数量级差异,也可能发生下溢

    • 更新规则本质是 $w \leftarrow w + \Delta w$,当 $|w| \gg |\Delta w|$ 时,在 FP16 中,$\Delta w$​ 可能直接被忽略

    • 一旦权重变成 0,梯度传播中断,该参数彻底失效

    • 因此,混合精度训练中通常会维护一份 FP32 权重副本,用于存储和更新权重

    • 前向传播时,FP32 权重会被转换为 FP16,用于计算

    • 反向传播时,计算得到的 FP16 梯度会被转换为 FP32,用于更新 FP32 权重

    • 这样可以确保权重更新的精度和范围,避免下溢问题

  • Loss scaling:

    • 梯度同样存在类似问题,因为梯度值通常远小于 1,因此极易发生下溢

    • 在深度网络中,梯度往往是 $10^{-3}$、$10^{-5}$​ 甚至更小,FP16 对这种数值极其不友好

    • 一个简单但非常有效的策略是在反向传播前放大 loss,在反向传播后再把梯度缩小回来

    • 这样可以保证反向传播过程中不会发生下溢,而且由于在后续处理(如梯度裁剪)和参数更新前会把梯度缩放回来,因此不会影响训练结果

    • 数值上,backward 用的是放大后的梯度,而语义上,optimizer 看到的仍是真实梯度

  • 累加(Accumulation):

    • 在 16-bit 精度下进行某些算术操作(例如求和、求平均)时,也可能发生下溢或上溢,如多个梯度求和、moving average、batch 内统计量

    • 解决方法是,在运算过程中使用 FP32 累加中间结果,最终结果再转换回 16-bit 精度

    • 即 compute 用 FP16 / BF16,accumulate 用 FP32

  • 通过上述技术,混合精度训练能够在保持训练稳定性的同时,大幅提升计算速度和显存利用率

FP8

  • 基本思想

    • 即便能够把通信与计算做到完美重叠,最终仍然会碰到一个无法回避的瓶颈:硬件本身的理论 FLOPS 上限,也就是单个算子在硬件上的执行效率

    • 前面所有优化(pipeline、overlap、fusion)解决的都是等待的问题,内存如何解决的问题,但现在的瓶颈是单次乘加本身有多快

    • 例如,在 NVIDIA H100 GPU 上 FP8 的矩阵乘法(GEMM)的理论 FLOPS 是 BF16 的 2 倍,这使得更低精度的训练成为进一步提升性能的潜在方向

    • 近期的研究工作,包括 FP8-LM、torchao 和 DeepSeek-V3,已经展示了 FP8 在大规模模型训练中的潜力

  • 然而,FP8 预训练带来了一个极其严峻的挑战:数值稳定性

    fp8-loss - 在更低精度下,数值不稳定性常常导致 loss 发散,使得训练效果难以达到高精度训练的水平 - 在模型规模固定的情况下,学习率越大,训练越不稳定,大模型预训练通常需要较大的 learning rate,而 FP8 对 learning rate 极其敏感

  • 首次公开成功的大规模 FP8 混合精度训练来自 DeepSeek-V3 的技术报告

    image.png - 作者对以下所有阶段的每一个算子都进行了细致分析: - 前向传播(Fprop) - 激活反向(Dgrad) - 权重反向(Wgrad) - 类似于 BF16 混合精度训练,一些聚合操作以及 master 权重仍然保留在更高精度中,而具体算子本身使用 FP8 - 这延续在 FP16/BF16 中已经见过的原则:compute 可以低精度、state 和 accumulation 必须高精度 - 为了从高精度(如 FP32、BF16)切换到数值范围更小的低精度(如 FP16、FP8),需要对激活值的范围进行归一化,例如计算其绝对最大值,这是因为FP8 表示范围极小,不缩放就必然溢出或下溢 - DeepSeek-V3 进一步提出了一种按 tile 进行归一化的量化方案: - 输入 / 激活:$1 \times 128$ - 权重与 scale:$128 \times 128$​ - 这种做法可以显著降低激活中异常值(outlier)对归一化范围的影响,如果全局 max → 极易被一个 outlier 毁掉,而使用 tile-wise max → 局部自适应缩放

  • 常见的 FP8 策略

    GEMM 精度方案
    主模型权重
    梯度累积
    模型权重
    梯度
    优化器状态
    总内存表示
    单参数内存占用

    BF16 + FP32 混合精度基线

    BF16

    FP32

    FP32

    BF16

    BF16

    FP32 + FP32

    4 + 4 + 2 + 2 + 4 + 4 = 20 字节

    去掉 FP32 梯度累积的上述方案

    BF16

    FP32

    不需要

    BF16

    BF16

    FP32 + FP32

    4 + 2 + 2 + 4 + 4 = 16 字节(减少 20%)

    Transformer Engine

    FP8

    不需要

    不需要

    FP32

    FP32

    FP32 + FP32

    4 + 4 + 4 + 4 = 16 字节(减少 20%)

    FP8-LM 的 O3 级别

    FP8

    FP16

    FP16

    FP8

    FP8

    FP8 + FP16

    2 + 2 + 1 + 1 + 1 + 2 = 9 字节(减少 55%)

    DeepSeek-V3

    FP8

    FP32

    FP32

    FP8

    BF16

    BF16 + BF16

    4 + 4 + 1 + 2 + 2 + 2 = 15 字节(减少 25%)

    Nanotron 的 FP8

    FP8

    BF16

    FP32

    FP8

    FP8

    FP8 + FP8

    2 + 4 + 1 + 1 + 1 + 1 = 10 字节(减少 50%)

  • 总体而言,截至 2025 年初,FP8 仍然是一种实验性技术

    • 这里的“实验性”不是指“不可用”,而是指:

      • 没有形成单一、稳定、被广泛验证的“标准 FP8 训练范式”

      • 不同团队在 scale、quantization、accumulation、optimizer 精度上仍有明显分歧

    • 和 BF16 的成熟度相比,BF16 几乎“无脑可用”,而 FP8 必须理解数值细节才能用好

分布式训练总结

image.png

LLM 参数量计算

  • 在讨论显存或计算量时,通常统计的是“元素数”(elements),可以把它理解为张量中的数值个数

  • 要得到真实的内存占用(字节数),需要再乘以每个数值的字节大小,例如 BF16 是 2 字节,FP32 是 4 字节

  • 输入 token 数:对于每个 micro-batch,处理的 token 数为 $\text{seq} \cdot \text{mbs}$,其中 $\text{mbs}$ 是 micro-batch size,$\text{seq}$ 是序列长度

  • 激活值(隐藏状态):对于单个 Transformer 层,隐藏状态张量的大小为 $\text{seq} \cdot \text{mbs} \cdot h$ 个元素,其中 $h$ 是 hidden size

  • 模型权重与梯度:模型中的每个权重矩阵(例如线性层)大约包含 $h^2$ 个元素,对应的梯度张量大小与权重完全相同

  • 优化器状态:对于每一个大小为 $h^2$ 的权重矩阵,使用 Adam 这类优化器并配合混合精度训练时,通常需要维护:

    • 一阶动量(FP32)

    • 二阶动量(FP32)

    • 共 $2 \times 2h^2$ 个 FP32 元素

    • 以及 FP32 的 master weights:$2h^2$ 个元素

    • 因此,每个权重矩阵对应的优化器状态总量大约为 $6h^2$ 个元素

  • Transformer 单层的参数量

    • 注意力部分参数:

      • QKV 投影:$3h^2$

      • 输出投影:$h^2$

    • MLP(采用 GLU 结构)参数:

      • gate 与 up 投影:$8h^2$(两组 $h \times 4h$ 的矩阵)

      • down 投影:$4h^2$(一组 $4h \times h$ 的矩阵)

    • 使用 GLU 的单层 Transformer,总参数量约为 $16h^2$

    • 不使用 GLU 的情况下,总参数量约为 $12h^2$

  • 全模型参数量:若模型包含 $\text{num_layers}$ 层 Transformer(使用 GLU),则参数总量约为 $16h^2 \cdot \text{num_layers}$

  • 额外参数

    • 输入词嵌入:$\text{vocab_size} \cdot h$

    • LM head(若未与输入嵌入权重共享):$\text{vocab_size} \cdot h$

    • 位置嵌入(若使用):$\text{max_seq_len} \cdot h$

  • 前向与反向计算量(FLOPs),一个非常粗略的估计是:

    • 前向传播 FLOPs:$2 \cdot \text{num_tokens} \cdot \text{num_params}$

    • 反向传播 FLOPs(约为前向的 2 倍):$4 \cdot \text{num_tokens} \cdot \text{num_params}$

  • 更精确的前向 + 反向 FLOPs 公式为 $6 \cdot \text{seq_len} \cdot \text{num_params} + 12 \cdot \text{num_layers} \cdot h \cdot \text{seq_len}^2$

    • 其中第二项刻画了自注意力在整个序列维度上的二次复杂度。为了简化分析,通常假设 $\text{seq_len}^2 \ll h$

Last updated

Was this helpful?