6. Standard Training Loop
The loop decomposes into three helpers: `move_to_device` moves a batch onto the device, `train_one_epoch` runs one training pass, and `validate` measures loss without gradients; a driver loop then ties them together epoch by epoch.

```python
import torch


def move_to_device(batch, device):
    # Recursively move tensors to the target device; leave everything else as-is.
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, list):
        return [move_to_device(v, device) for v in batch]
    return batch


def train_one_epoch(model, dataloader, optimizer, device, scheduler=None):
    model.train()  # switch to training mode
    # By now the Sampler has chosen the indices, the DataLoader has built the
    # batch, and collate_fn has aligned its structure.
    for batch in dataloader:
        batch = move_to_device(batch, device)
        outputs = model(**batch["inputs"])
        # loss must be a scalar, must live on the current device, and must have grad.
        loss = outputs["loss"]
        # autograd starts from loss and backpropagates along the computation
        # graph; each parameter's grad is accumulated.
        loss.backward()
        # Use the current grads to update the model's parameters.
        optimizer.step()
        # After the update, clear the gradient history so the next batch's
        # gradients do not stack on top.
        optimizer.zero_grad()
        if scheduler is not None:
            # The learning-rate timeline may instead be epoch-level, i.e.:
            #   for epoch in range(num_epochs):
            #       train_one_epoch(...)
            #       scheduler.step()
            scheduler.step()


def validate(model, dataloader, device):
    model.eval()  # switch to eval mode
    total_loss = 0
    with torch.no_grad():  # no gradient computation needed
        for batch in dataloader:
            batch = move_to_device(batch, device)
            outputs = model(**batch["inputs"])
            total_loss += outputs["loss"].item()
    return total_loss / len(dataloader)


# An epoch is one pass over the data. The Dataset does not know the epoch;
# the Sampler (under DDP) does need to know it.
# model, train_loader, val_loader, optimizer, device, num_epochs, sampler,
# and rank are assumed to exist in the surrounding context.
for epoch in range(num_epochs):
    if sampler is not None:
        sampler.set_epoch(epoch)
    train_one_epoch(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)
    if rank == 0:  # log only on the main process
        print(f"epoch {epoch}, val_loss={val_loss:.4f}")
```
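The `sampler.set_epoch(epoch)` call only comes into play with a `DistributedSampler` under DDP: each rank sees a shard of the data, and the epoch number re-seeds the shuffle so the shards change between epochs. A minimal sketch, assuming the process group has already been initialized; `dataset` and `num_epochs` are placeholders:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Assumes torch.distributed.init_process_group(...) has already run, so the
# sampler can infer world size and rank. `dataset` is a placeholder Dataset.
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    # Without set_epoch, every epoch reuses the same per-rank shuffle order.
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
```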
Training Loop Enhancements
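A minimal sketch of two common enhancements, assuming they are what this heading intends: automatic mixed precision and gradient norm clipping, layered onto `train_one_epoch` above. `max_norm=1.0` is an illustrative value, not a recommendation from the original.

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def train_one_epoch_amp(model, dataloader, optimizer, device, max_norm=1.0):
    model.train()
    for batch in dataloader:
        batch = move_to_device(batch, device)
        with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
            outputs = model(**batch["inputs"])
            loss = outputs["loss"]
        scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
        scaler.unscale_(optimizer)       # unscale grads before clipping them
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)           # skips the step if grads contain inf/nan
        scaler.update()
        optimizer.zero_grad()
```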
Complete Example
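A minimal end-to-end sketch wiring the pieces above together. `ToyDataset` and `ToyModel` are hypothetical stand-ins; they only satisfy the two contracts the loop imposes: each batch carries an `"inputs"` dict fed to the model as keyword arguments, and the model returns a dict with a scalar `"loss"`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    # Hypothetical regression data; the shapes are arbitrary illustration.
    def __init__(self, n=256):
        self.x = torch.randn(n, 16)
        self.y = torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        # The default collate_fn stacks this nested dict into batch["inputs"].
        return {"inputs": {"x": self.x[i], "y": self.y[i]}}


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 1)

    def forward(self, x, y):
        # Return a dict containing a scalar "loss", as the loop expects.
        return {"loss": nn.functional.mse_loss(self.fc(x), y)}


device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
train_loader = DataLoader(ToyDataset(), batch_size=32, shuffle=True)
val_loader = DataLoader(ToyDataset(), batch_size=32)

for epoch in range(3):
    train_one_epoch(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)
    print(f"epoch {epoch}, val_loss={val_loss:.4f}")
```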