5. 数据集与加载器

Dataset

  • Dataset 是一个“样本索引 → 单条样本数据”的映射

  • 它只回答一个问题:给定第 i 个样本,应该返回什么?

  • Dataset 只负责三件事

    • 定位样本

      self.samples[i]
    • 读取样本:从磁盘/内存/网络读取样本

    • 返回模型可用的最小结构,如 Tensor / 数值 / 字典,返回的是单条样本,而不是 batch

      return {
          "input": x,
          "target": y
      }
  • Dataset 不应该做的事

    • 不做 batch:batch 是 DataLoader 的职责

      # 错误示例
      def __getitem__(self, idx):
          return batch_x, batch_y
    • 不做 shuffle:shuffle 是采样策略,应由 sampler 控制,多 worker 时不可控

      random.shuffle(self.samples)
    • 不放 GPU:Dataset 在 worker 进程中,而 GPU 只能在主进程 / 训练循环,会导致显存泄漏或 crash

      x = x.to("cuda")   # 错误
    • 不依赖全局可变状态:worker 是多进程,状态不同步,行为不可复现

      global counter
      counter += 1
  • Dataset 的组成结构

    class MyDataset(torch.utils.data.Dataset):
        def __init__(self, data_list):
            self.data = data_list
    
        # Dataset 的逻辑长度,sampler / DataLoader 都会依赖它
        # 不等于 epoch 样本数,分布式会自动切分数据集
        def __len__(self):
            return len(self.data)
    	
        # 给定 idx,返回一条确定的样本,Dataset 应该是幂等的
        def __getitem__(self, idx):
            item = self.data[idx]
            x = item["x"]
            y = item["y"]
            return x, y
  • Dataset 输出的常见风格

    • Tuple:适合单任务

      return x, y
    • Dict:模型接口稳定,容易扩展

      return {
          "input_ids": x,
          "labels": y,
          "attention_mask": mask
      }
    • NamedTuple / dataclass:更强的类型约束,适合大型工程

  • Map-style Dataset(标准 Dataset)

    class MyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return N
    
        def __getitem__(self, idx):
            return sample
  • 对于 Map-style Dataset,因为能 index,PyTorch 可以:

    • shuffle(通过 sampler)

    • 切分数据(train / val)

    • 分布式切 shard

    • resume(按 index 恢复)

    • random access augmentation

  • 此时,Map-style Dataset 对应的 DataLoader 的行为模型如下

    sampler → indices → __getitem__ → collate → batch
  • Iterable-style Dataset(IterableDataset)

    class MyIterableDataset(torch.utils.data.IterableDataset):
        def __iter__(self):
            for sample in stream:
                yield sample
  • Iterable-style Dataset 只有 __iter__,没有 __len__,无法随机访问

    • 这是由于有些数据没有第 i 条或者代价巨大

    • 比如日志流、无限数据生成器、超大数据、Kafka / WebDataset

  • 此时,Iterable-style Dataset 对应的 DataLoader 的行为模型如下

    # 没有 sampler
    __iter__ → sample → collate → batch
  • 如果必须使用 Iterable-style Dataset,就必须自己承担责任

    # 多 worker 下手动切分,否则每个 worker 读全量数据,会造成数据重复
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
    
        if worker_info is None:
            start = 0
            step = 1
        else:
            start = worker_info.id
            step = worker_info.num_workers
    
        for i, sample in enumerate(self.stream()):
            if i % step == start:
                yield sample
    # 分布式训练(数据并行)还需要再进行切分,需要额外考虑 rank、world_size

DataLoader

  • DataLoader 本质上是一个样本调度 + 批构造 + 并行加载的编排流程

  • DataLoader 不考虑数据语义,只做四件事:

    • 决定“按什么顺序”取样本

    • 决定“一次取多少”

    • 决定“谁来取”(主进程 / worker)

    • 决定“怎么拼成 batch”

  • 最常见的 Map-style Dataset 的迭代流程

  • DataLoader 示例

  • 如果指定 shuffle=True,DataLoader 内部将使用 RandomSampler,sampler 生成索引顺序,等价于

  • batch_size=N 指得是 sampler 生成 N 个 index,依次调用 N 次 dataset[i],得到 N 个 sample,交给 collate_fn,然后 collate_fn 返回 batch

  • 由此可见,sampler 决定顺序,batch_sampler 决定分组

  • 常见 sampler

    • SequentialSampler(验证 / 测试)

    • RandomSampler(训练)

    • DistributedSampler(分布式)

  • num_workers

    • num_workers = 0:所有 dataset[i] 在主进程,容易调试,但读取数据集太慢

    • num_workers > 0:fork / spawn 出 worker 进程,每个 worker 各自持有一份 Dataset,worker 并行调用 __getitem__

    • 注意,Dataset 会被 pickle,因此不能依赖共享状态,不能在 Dataset 里存 socket / file handler(除非手动管理)

  • DataLoader 不该处理的事

    • 不该理解样本结构

    • 不该做 padding

    • 不该做 tokenizer

    • 不该做 augmentation

  • 这些事情应该交由 Dataset 或 collate_fn 处理

collate_fn

  • 在上面的例子中,DataLoader 取出数据之后,会得到 4 个 Dataset 返回的 dict

  • 而默认的 collate_fn 会将其处理为

  • collate_fn 实际上是在保证输出结构是规整的,这也是为什么 Dataset 最好返回 dict

    • tuple → zip + stack

    • list → zip + stack

    • dict → 按 key stack

  • 在工程上可以用如下方式调用

  • 那么如何自定义 collate_fn,实际上,collate_fn 接收的东西只有一个 batch 的 sample 列表

  • 所以,Dataset → 定义单样本结构,而 collate_fn → 定义 batch 结构

  • 最小自定义 collate_fn

  • 示例:返回变长序列

  • Padding 到 batch 内最大长度

  • padding 一定在 collate_fn 中,而不是 Dataset,因为 Dataset 不知道 batch 的上下文,而 padding 是 batch 级别的操作

  • 在 collate_fn 中不能使用 GPU 操作,因为 collate 是在 worker 进程中的,而 CUDA 不能跨进程安全使用,应该在训练循环中 .to(device)

sampler

  • Sampler 是一个 index 生成器:

  • 常见 sampler 对应关系

    • shuffle=False → SequentialSampler

    • shuffle=True → RandomSampler

    • DDP → DistributedSampler

  • 手动指定 sampler

  • DDP(Distributed Data Parallel):在数据并行中,每张 GPU = 一个独立 Python 进程,每个进程都有一份 Dataset、一份 DataLoader、一份 Sampler

    • 在 DDP 里,一个参与者 = 一个进程 = 一张 GPU

    • world_size: 参与分布式训练的进程总数

    • rank 是这个进程在整个并行世界里的 ID

  • 而 DistributedSampler 中,给定 dataset size = $N$,world_size = $W$,rank = $r$,它会把 index 划分为

  • 每个样本只会被一个进程看到,所有进程加起来 = 完整 dataset

  • 在 DDP 中,shuffle 必须由 DistributedSampler 自己做 shuffle

  • sampler 在每个 epoch 中 都要重新随机分配

  • 在 DDP 中,假设全局 batch 大小为 256,world_size 为 8,那么每个 DP 设备的 batch 应该是 256/8=32,每个进程的 sampler 只给 1/8 的样本,DataLoader 每次拿 32 个,梯度通过 AllReduce 汇总 —— DDP 不会自动帮你拆 batch

Dataset 与训练循环

  • 一个标准的训练循环

性能问题

  • 数据流程如下

  • DataLoader 性能问题,本质只有一个:GPU 的计算速度比数据处理得快

  • num_workers

    • num_workers = 同时跑 Dataset._getitem_ 的子进程数量

    • Dataset 里通常包含文件读取、解码、tokenize、augmentation,这些 CPU-bound / IO-bound 操作非常适合多进程

    • 本地 SSD + 简单处理,num_workers 一般设置为 4-8

    • 重预处理(tokenize / decode),num_workers 一般设置为 8-16

    • 网络存储,num_workers 一般取决于带宽大小

    • num_workers 不是越大越好,num_workers 越大,CPU 占用越高,导致 worker 队列阻塞,性能反而会下降

  • prefetch_factor

    • 每个 worker 预先准备的 batch 数,默认 prefetch_factor = 2

    • 实际总预取 batch 数:num_workers × prefetch_factor

    • 如果 batch 构造慢,GPU 计算很快,那么 GPU 会等待 batch 的处理,而增加 prefetch 可以减少气泡,提高吞吐的稳定性

    • 在 GPU utilization 抖动或训练时偶尔卡顿时,可以尝试改动 prefetch_factor

  • pin_memory

    • 普通的 CPU 内存可以被操作系统 swap,而 pinned memory 物理页被锁定,不能被操作系统 swap

    • 当指定 pin_memory=True 时,batch 会被拷贝到 pinned CPU memory,.to(device, non_blocking=True) 可以异步进行

    • 因此一般要成对配合,否则 pin_memory 几乎没有收益

    • 当 GPU 训练/ batch size 较大/数据复制的 IO 开销成为瓶颈时,可以考虑启用 pin_memory

    • 当启用 pin_memory 时,一定要配合 non_blocking 使用

  • non_blocking=True

    • tensor.to(device, non_blocking=True) 中,当 tensor 在 pinned memory 且 CUDA steam 可用时,可以将复制与计算重叠,提高 GPU 的利用率

  • persistent_workers

    • 默认情况下,每个 epoch 会先 fork workers,训练后再回收 workers

    • 而 fork 的开销是比较昂贵的,通过保留 workers,可以使 workers 在 epoch 之间复用,大幅降低 epoch 切换开销

    • 当 num_workers 大于 0,且 Dataset 的生命周期大于多个 epoch 时,可以启用该选项

  • 性能友好的 DataLoader 配置

  • 80% 的 GPU 空闲来自 IO、tokenize 等重预处理、Python 逻辑等,可以通过 GPU utilization、torch.profiler、Nsight Systems 等来判断训练缓慢的瓶颈

Last updated

Was this helpful?