5. Datasets and Loaders
Dataset
A Dataset's core job is simple: given an index, look up one sample, e.g. `self.samples[i]`, and return it in a fixed structure such as `{ "input": x, "target": y }`.
```python
# Anti-patterns: things a Dataset / __getitem__ should NOT do
# (the correct version follows below).

def __getitem__(self, idx):
    return batch_x, batch_y      # wrong: returns a whole batch instead of one sample

random.shuffle(self.samples)     # wrong: shuffling belongs to the sampler / DataLoader(shuffle=True)

x = x.to("cuda")                 # wrong: device transfer inside the Dataset breaks multi-worker loading

global counter
counter += 1                     # wrong: mutable global state makes samples non-deterministic
```
```python
import torch

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data = data_list

    # Logical length of the Dataset; the sampler and DataLoader rely on it.
    # Not necessarily the number of samples seen per epoch: under distributed
    # training the sampler shards the indices across ranks.
    def __len__(self):
        return len(self.data)

    # Given idx, return one deterministic sample; __getitem__ should be idempotent.
    def __getitem__(self, idx):
        item = self.data[idx]
        x = item["x"]
        y = item["y"]
        return x, y
```

The return value can be a plain tuple (`return x, y`) or a dict with named fields (`return { "input_ids": x, "labels": y, "attention_mask": mask }`); either works, as long as every sample uses the same structure so the collate step can stack them.
A map-style Dataset exposes a length plus random access:

```python
class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return N

    def __getitem__(self, idx):
        return sample
```

Data flow: sampler → indices → `__getitem__` → collate → batch.

An IterableDataset instead yields samples as a stream:

```python
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        for sample in stream:
            yield sample
```

There is no sampler; the data flow is `__iter__` → sample → collate → batch.

With multiple workers the stream must be sharded manually, otherwise every worker reads the full stream and every sample is duplicated:

```python
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Single-process loading: take everything.
        start = 0
        step = 1
    else:
        # Each worker takes every num_workers-th sample, offset by its id.
        start = worker_info.id
        step = worker_info.num_workers
    for i, sample in enumerate(self.stream()):
        if i % step == start:
            yield sample
```

Distributed (data-parallel) training needs one more level of sharding on top of this, which also has to take rank and world_size into account.
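As a sketch of that combined sharding, assuming the same placeholder `self.stream()` source as above and that `torch.distributed` is initialized when running data-parallel:

```python
import torch
import torch.distributed as dist

class ShardedStream(torch.utils.data.IterableDataset):
    def __iter__(self):
        # Rank-level shard (data parallel): which process this is.
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1

        # Worker-level shard: which DataLoader worker inside this process.
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1

        # Global shard index over rank × worker.
        shard = rank * num_workers + worker_id
        num_shards = world_size * num_workers

        for i, sample in enumerate(self.stream()):  # self.stream() is a placeholder source
            if i % num_shards == shard:
                yield sample
```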
DataLoader
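A minimal sketch of wiring the Dataset above into a DataLoader; the argument values are illustrative starting points, not prescriptions:

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,            # an instance of the MyDataset defined above
    batch_size=32,      # samples per batch
    shuffle=True,       # map-style datasets only; invalid for IterableDataset
    num_workers=4,      # subprocesses running __getitem__ in parallel
    pin_memory=True,    # page-locked host memory for faster GPU transfer
    drop_last=True,     # drop the trailing incomplete batch
)

for batch_x, batch_y in loader:
    ...
```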
collate_fn
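A sketch of a custom collate_fn for the dict-style samples shown earlier, assuming `input_ids` are variable-length 1-D tensors, `labels` are scalars, and 0 is the padding id:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate(batch):
    # batch is a list of per-sample dicts produced by __getitem__.
    input_ids = [torch.as_tensor(b["input_ids"]) for b in batch]
    labels = torch.as_tensor([b["labels"] for b in batch])

    # Pad to the longest sequence in the batch; mask marks the non-padding positions.
    padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    mask = (padded != 0).long()

    return {"input_ids": padded, "labels": labels, "attention_mask": mask}

# loader = DataLoader(dataset, batch_size=32, collate_fn=collate)
```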
sampler
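A sketch of plugging samplers into the DataLoader; WeightedRandomSampler and DistributedSampler are just two common choices, and `sample_weights` is a placeholder:

```python
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler

# Oversample rare classes: one weight per sample.
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(dataset),
                                replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # sampler and shuffle are mutually exclusive

# Distributed data parallel: each rank sees a disjoint slice of the indices.
dist_sampler = DistributedSampler(dataset, shuffle=True)
dist_loader = DataLoader(dataset, batch_size=32, sampler=dist_sampler)
# Call dist_sampler.set_epoch(epoch) at the start of every epoch to reshuffle.
```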
Dataset and the Training Loop
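A minimal sketch of the loop consuming batches from the loader; `model`, `optimizer`, `loss_fn`, `device` and `num_epochs` are assumed to exist, and the dict keys follow the collate example above:

```python
for epoch in range(num_epochs):
    for batch in loader:
        # Device transfer happens here, not inside __getitem__.
        x = batch["input_ids"].to(device, non_blocking=True)
        y = batch["labels"].to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
```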
Performance Issues
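A sketch of the DataLoader options that usually matter most for input-pipeline throughput; the specific values are illustrative starting points:

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,            # parallel sample loading; tune to CPU cores and IO
    pin_memory=True,          # enables async .to(device, non_blocking=True)
    prefetch_factor=2,        # batches prefetched per worker (requires num_workers > 0)
    persistent_workers=True,  # keep workers alive across epochs, avoids respawn cost
)
```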