6. 标准训练流程

标准训练流程

  • 标准训练流程 = 训练/验证模式切换 → 取 batch → 前向 → loss → 反向 → 更新 → 清理状态 → 记录

  • 最外层结构(epoch 级)

    # epoch 是数据遍历次数
    # Dataset 不知道 eposh
    # Sampler(DDP)时需要知道 epoch
    for epoch in range(num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
    
        train_one_epoch(...)
        val_loss = validate(...)
    
        if rank == 0:
            print(f"epoch {epoch}, val_loss={val_loss:.4f}")
  • train_one_epoch

    def move_to_device(batch, device):
        if isinstance(batch, torch.Tensor):
            return batch.to(device, non_blocking=True)
        if isinstance(batch, dict):
            return {k: move_to_device(v, device) for k, v in batch.items()}
        if isinstance(batch, list):
            return [move_to_device(v, device) for v in batch]
        return batch
    
    def train_one_epoch(
        model,
        dataloader,
        optimizer,
        device,
        scheduler=None
    ):
        model.train() # 切换到训练模式
    
        # Sampler 已决定 index、DataLoader 已构造 batch、collate_fn 已对齐结构
        for batch in dataloader:
            batch = move_to_device(batch, device)
    
            outputs = model(**batch["inputs"])
            # loss 必须是标量,loss 必须在当前 device,loss 必须有 grad
            loss = outputs["loss"]
    		# autograd 从 loss 出发,沿计算图反向传播,每个参数的 grad 被累积
            loss.backward()
            # 使用当前 grad,更新模型的参数
            optimizer.step()
            # 在更新之后,清空梯度历史,防止下一个 batch 梯度叠加
            optimizer.zero_grad()
    
            if scheduler is not None:
                # 学习率的时间轴,可能会是 epoch 级别的更新,即
                # for epoch:
                #     train_one_epoch()
                #     scheduler.step()
                scheduler.step()
  • validate

    def validate(model, dataloader, device):
        model.eval() # 切换至验证模式
        total_loss = 0
    
        with torch.no_grad(): # 无需计算梯度
            for batch in dataloader:
                batch = move_to_device(batch, device)
                outputs = model(**batch["inputs"])
                total_loss += outputs["loss"].item()
    
        return total_loss / len(dataloader)
  • 在一个完整的训练流程中,不应该出现 Dataset 逻辑、DataLoader 配置、Model 定义、分布式初始化

训练流程增强

  • 增强后的训练流程

  • 增强后的验证流程

AMP

  • AMP(Automatic Mixed Precision) 要实现的是:用更低精度算得更快,但又不把数值稳定性搞崩

  • 深度学习里的计算任务可以分为两类

    • 计算密集型:matmul、conv、attention

    • 数值敏感型:loss、梯度累积、optimizer 更新

  • AMP 的核心思想是不同类型的算子,使用不同精度

  • fp16 的致命缺陷是表示范围小、精度低,在反向传播中梯度往往非常小($10^{-8}$、$10^{-10}$),而 fp16 表示不了 → underflow → 梯度直接变 0,从而导致梯度消失,模型无法学习

  • AMP 由两个组件构成

    • autocast → 控制 forward 的计算精度,只影响 forward,不改变模型参数的存储精度(仍是 fp32),而是在 forward 时临时切换算子精度

    • GradScaler → 控制 backward 的数值稳定性,解决梯度 underflow,其先把 loss 放大,再反向传播,反向传播后再把梯度缩小

  • AMP 的完整顺序

  • AMP 并不会减少显存里的参数大小,其参数仍是 fp32,而显存节省主要来自激活/中间结果

  • AMP 不保证一定更快,当 batch 很小或者 CPU/IO 瓶颈时收益有限

  • 在 loss 本身极不稳定或自定义 CUDA op 不支持 fp16 时,不能使用 AMP

  • AMP + DDP 为什么不会冲突

    • DDP 的 AllReduce 用的是 parameter.grad,而 AMP 在 AllReduce 之前已经把它 unscale 成 fp32

    • 一轮 iteration 的完整顺序

    • DDP 同步发生在 backward 结束之后,且看到的是已经 unscale 的梯度

    • 在 PyTorch 中,DDP 会在 backward 时给每个参数注册一个 autograd hook,当梯度计算完成时,hook 触发,对该参数的梯度做 AllReduce

    • 但这个 hook 看到的是 grad 张量本身,而 AMP 的流程是 backward 结束、scaler.unscale_() 改写 grad 的数值,此时 DDP hook 同步的是 unscale 后的结果

  • GradScaler 不会影响多卡一致性,因为所有 rank 用的是同一个 loss scale 策略,如果某个 rank 出现 inf / nan,scaler.step() 会在所有 rank 跳过更新

  • 推理阶段的混合精度与训练阶段的 AMP

    • 训练的 AMP 是“数值安全工程”,而推理的混合精度是“纯性能工程”

    • 推理阶段的流程

    • 推理阶段无需计算梯度,数值问题主要来自 softmax / normalization,而这些算子会由 autocast 保留为 fp32,因此不需要 GradScaler

    • 对比

学习率调度(Scheduler)

  • Step / Cosine

  • Warmup 是 scheduler 的一部分,不应散落在训练循环各处

Checkpoint

  • 保存 checkpoint (只在 rank 0):保存参数状态、优化器动量、scheduler 位置、AMP 缩放状态、当前 epoch

  • 恢复 checkpoint

梯度裁剪

  • 位置放在 backward 之后,optimizer.step 之前

  • clip 是训练稳定性工具,不改变 loss,只限制更新幅度

早停(Early Stopping)

  • 早停不在 train_one_epoch 里,而在 epoch 外

完整示例

  • 一轮训练 iteration 中,真实发生的顺序是:

    • Dataset 定义「样本是什么」

    • DataLoader 定义「batch 如何产生」

    • DistributedSampler 决定「每张卡拿哪些样本」

    • model = DDP(model)

    • forward 在 autocast 里

    • loss 计算(fp32)

    • scaler.scale(loss).backward()

    • scaler.unscale_(optimizer)

    • 可选:gradient clipping

    • optimizer.step()

  • 第一步,实现 Dataset

  • 第二步,初始化分布式

  • 第三步,实现 DataLoader

  • 第四步,实现模型

  • 第五步,初始化相关组件

  • 第六步,实现训练流程

  • 第七步,实现验证流程

  • 第八步,完整代码流程

Last updated

Was this helpful?