评价此页

torch.utils.data#

创建日期:2025年6月13日 | 最后更新日期:2025年12月16日

PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它代表了一个数据集上的 Python 可迭代对象,支持:

这些选项通过 DataLoader 的构造函数参数进行配置,其签名为:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下章节详细描述了这些选项的效果和用法。

数据集类型#

DataLoader 构造函数最重要的参数是 dataset,它指明了从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:

映射式数据集#

映射式数据集是实现了 __getitem__()__len__() 协议的数据集,代表了从(可能是非整数的)索引/键到数据样本的映射。

例如,当使用 dataset[idx] 访问此类数据集时,它可以从磁盘文件夹中读取第 idx 个图像及其对应的标签。

更多详细信息,请参阅 Dataset

迭代式数据集#

迭代式数据集是 IterableDataset 子类的实例,它实现了 __iter__() 协议,并代表了数据样本上的一个可迭代对象。这种类型的数据集特别适用于随机读取开销很大甚至不可能,以及批大小取决于所获取数据的情况。

例如,当调用 iter(dataset) 时,此类数据集可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。

更多详细信息,请参阅 IterableDataset

注意

当在 多进程数据加载 中使用 IterableDataset 时,同一个数据集对象会在每个工作进程(worker process)上进行复制,因此必须对副本进行不同的配置以避免数据重复。请参阅 IterableDataset 文档了解如何实现这一点。

数据加载顺序与 Sampler#

对于 迭代式数据集,数据加载顺序完全由用户定义的可迭代对象控制。这允许更轻松地实现块读取和动态批大小(例如,每次产生一个批处理样本)。

本节其余部分涉及 映射式数据集 的情况。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键序列。它们代表数据集索引上的可迭代对象。例如,在随机梯度下降(SGD)的常见情况下,Sampler 可以随机排列索引列表并一次产生一个,或者为小批量 SGD 产生少量索引。

顺序采样器或随机采样器将根据 DataLoadershuffle 参数自动构建。或者,用户可以使用 sampler 参数指定一个自定义的 Sampler 对象,该对象每次产生下一个要获取的索引/键。

可以通过 batch_sampler 参数传递一个自定义的 Sampler,该采样器一次产生一批索引列表。自动批处理也可以通过 batch_sizedrop_last 参数启用。有关此内容的更多详细信息,请参阅 下一节

注意

samplerbatch_sampler 都不兼容迭代式数据集,因为此类数据集没有键或索引的概念。

加载批处理与非批处理数据#

DataLoader 支持通过 batch_sizedrop_lastbatch_samplercollate_fn(具有默认函数)参数将获取的各个数据样本自动整理(collate)成批次。

自动批处理(默认)#

这是最常见的情况,对应于获取一批数据并将其整理为批处理样本,即包含其中一个维度为批维度(通常是第一维)的张量。

batch_size(默认 1)不为 None 时,数据加载器会产生批处理样本而不是单个样本。batch_sizedrop_last 参数用于指定数据加载器如何获取数据集键的批次。对于映射式数据集,用户还可以选择指定 batch_sampler,它每次产生一个键列表。

注意

batch_sizedrop_last 参数本质上用于从 sampler 构建 batch_sampler。对于映射式数据集,sampler 要么由用户提供,要么根据 shuffle 参数构建。对于迭代式数据集,sampler 是一个虚拟的无限采样器。有关采样器的更多详细信息,请参阅 本节

注意

当通过 多进程 获取 迭代式数据集 时,drop_last 参数会丢弃每个工作进程数据集副本的最后一个不完整的批次。

在使用来自采样器的索引获取样本列表后,传递给 collate_fn 参数的函数被用于将样本列表整理成批次。

在这种情况下,从映射式数据集加载大致等同于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

而从迭代式数据集加载大致等同于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义的 collate_fn 可用于自定义整理方式,例如将序列数据填充到批次的最大长度。有关 collate_fn 的更多信息,请参阅 本节

禁用自动批处理#

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者只需加载单个样本。例如,直接加载批处理数据(例如从数据库批量读取或读取连续的内存块)可能成本更低,或者批大小依赖于数据,或者程序旨在处理单个样本。在这些场景下,最好不要使用自动批处理(即使用 collate_fn 来整理样本),而是让数据加载器直接返回 dataset 对象的每个成员。

batch_sizebatch_sampler 均为 None 时(batch_sampler 的默认值已为 None),自动批处理被禁用。从 dataset 获取的每个样本都会使用作为 collate_fn 参数传递的函数进行处理。

当禁用自动批处理时,默认的 collate_fn 仅将 NumPy 数组转换为 PyTorch 张量,并保持其他所有内容不变。

在这种情况下,从映射式数据集加载大致等同于:

for index in sampler:
    yield collate_fn(dataset[index])

而从迭代式数据集加载大致等同于:

for data in iter(dataset):
    yield collate_fn(data)

有关 collate_fn 的更多信息,请参阅 本节

使用 collate_fn#

在启用或禁用自动批处理时,collate_fn 的用法略有不同。

当禁用自动批处理时collate_fn 会被每个单独的数据样本调用,输出结果由数据加载器迭代器产生。在这种情况下,默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch 张量。

当启用自动批处理时collate_fn 每次被调用时都会接收一个数据样本列表。它期望将输入的样本整理成一个批次,以便由数据加载器迭代器产生。本节的其余部分描述了默认 collate_fn (default_collate()) 的行为。

例如,如果每个数据样本由一个 3 通道图像和一个整数类标签组成,即数据集的每个元素返回一个元组 (image, class_index),默认的 collate_fn 会将这些元组的列表整理成一个包含批处理图像张量和批处理类标签张量的单一元组。特别地,默认的 collate_fn 具有以下属性:

  • 它总是添加一个新的维度作为批维度。

  • 它自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

  • 它保留数据结构,例如,如果每个样本是一个字典,它输出一个具有相同键集但值为批处理张量(如果值不能转换为张量,则为列表)的字典。对于 listtuplenamedtuple 等也是如此。

用户可以使用自定义的 collate_fn 来实现自定义批处理,例如沿除第一维之外的维度进行整理、填充不同长度的序列或为自定义数据类型添加支持。

如果您发现 DataLoader 的输出维度或类型与预期不符,您可能需要检查您的 collate_fn

单进程和多进程数据加载#

DataLoader 默认使用单进程数据加载。

在 Python 进程中,全局解释器锁 (GIL) 阻止了 Python 代码在线程间进行真正的完全并行化。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的开关,只需将 num_workers 参数设置为正整数即可进行多进程数据加载。

单进程数据加载(默认)#

在此模式下,数据获取与 DataLoader 的初始化在同一个进程中完成。因此,数据加载可能会阻塞计算。然而,当用于进程间共享数据(例如共享内存、文件描述符)的资源有限,或者整个数据集很小且可以完全加载到内存中时,可能更倾向于使用此模式。此外,单进程加载通常显示出更易读的错误追踪,因此对于调试非常有用。

多进程数据加载#

num_workers 参数设置为正整数将开启多进程数据加载,并使用指定数量的工作进程。

警告

经过多次迭代,工作进程将消耗与父进程相同的 CPU 内存量(针对父进程中被工作进程访问的所有 Python 对象)。如果 dataset 包含大量数据(例如在数据集构建时加载了非常大的文件名列表)和/或使用了大量工作进程(总体内存使用量为 工作进程数量 * 父进程大小),这可能会带来问题。最简单的解决方法是将 Python 对象替换为不使用引用计数的表示形式,例如 Pandas、Numpy 或 PyArrow 对象。请查看 issue #13246 以获取有关为何发生这种情况的更多详细信息以及解决这些问题的示例代码。

在此模式下,每次创建 DataLoader 的迭代器(例如调用 enumerate(dataloader))时,都会创建 num_workers 个工作进程。此时,datasetcollate_fnworker_init_fn 被传递给每个工作进程,并在那里用于初始化和获取数据。这意味着数据集访问及其内部 IO、转换(包括 collate_fn)都在工作进程中运行。

torch.utils.data.get_worker_info() 在工作进程中返回各种有用的信息(包括 worker id、数据集副本、初始种子等),而在主进程中返回 None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片(sharding)时特别有帮助。

对于映射式数据集,主进程使用 sampler 生成索引并将其发送给工作进程。因此,任何随机洗牌都在主进程中完成,通过分配要加载的索引来指导加载。

对于迭代式数据集,由于每个工作进程都会获得一个 dataset 对象的副本,因此天真的多进程加载通常会导致数据重复。使用 torch.utils.data.get_worker_info() 和/或 worker_init_fn,用户可以独立配置每个副本。(有关如何实现这一点,请参阅 IterableDataset 文档。)出于类似的原因,在多进程加载中,drop_last 参数会丢弃每个 worker 的迭代式数据集副本中的最后一个不完整的批次。

一旦达到迭代结束或迭代器被垃圾回收,工作进程就会关闭。

警告

通常不建议在多进程加载中返回 CUDA 张量,因为在多进程中使用 CUDA 和共享 CUDA 张量存在许多细微差别(请参阅 多进程中的 CUDA)。相反,我们建议使用 自动内存锁定(即设置 pin_memory=True),这可以实现向支持 CUDA 的 GPU 进行快速数据传输。

平台特定行为#

由于工作进程依赖于 Python 的 multiprocessing,Windows 上的工作进程启动行为与 Unix 不同。

  • 在 Unix 上,Python >= 3.14 的默认 multiprocessing 启动方法是 forkserver();Python < 3.14 为 fork()。使用 fork(),子进程通常可以直接通过克隆的地址空间访问 dataset 和 Python 参数函数。这可以实现快速启动,但会导致多线程应用程序出现问题。在支持它的 Unix 平台上,forkserver() 首先启动一个单独的服务器进程,然后所有新的工作进程都由该服务器产生,这比 fork()(特别是对于线程)提供了更安全的隔离,同时避免了纯 spawn() 的一些开销。

  • 在 Windows 和 MacOS 上,spawn() 是默认的 multiprocessing 启动方法。使用 spawn() 时,会启动另一个解释器来运行您的主脚本,随后运行内部工作函数,该函数通过 pickle 序列化接收 datasetcollate_fn 和其他参数。

这种单独的序列化意味着您应该采取两个步骤来确保在使用多进程数据加载时与 Windows 兼容:

  • 将大部分主脚本的代码封装在 if __name__ == '__main__': 块中,以确保在启动每个工作进程时它不会再次运行(否则极易产生错误)。您可以将数据集和 DataLoader 实例创建逻辑放在这里,因为它不需要在工作进程中重新执行。

  • 确保任何自定义的 collate_fnworker_init_fndataset 代码都被声明为顶级定义,放在 __main__ 检查之外。这确保它们在工作进程中可用。(这是必要的,因为函数仅作为引用被序列化,而不是 bytecode。)

多进程数据加载中的随机性#

默认情况下,每个工作进程的 PyTorch 种子将设置为 base_seed + worker_id,其中 base_seed 是主进程使用其 RNG 生成的长整数(从而强制消耗 RNG 状态)或指定的 generator。但是,初始化工作进程时,其他库的种子可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的 本节。)

worker_init_fn 中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 访问为每个工作进程设置的 PyTorch 种子,并在加载数据之前将其用于初始化其他库的种子。

内存锁定(Memory Pinning)#

当主机到 GPU 的拷贝源自锁定(分页锁定)内存时,它们的速度会快得多。有关何时以及如何使用锁定内存的更多详细信息,请参阅 使用锁定内存缓冲区

对于数据加载,将 pin_memory=True 传递给 DataLoader 将自动将获取的数据张量置于锁定内存中,从而实现向支持 CUDA 的 GPU 的快速数据传输。

默认的内存锁定逻辑仅识别张量以及包含张量的映射和可迭代对象。默认情况下,如果锁定逻辑看到一个自定义类型的批次(如果您有一个返回自定义批处理类型的 collate_fn,就会发生这种情况),或者如果您的批次中的每个元素都是自定义类型,则锁定逻辑将无法识别它们,并会原样返回该批次(或那些元素)而不进行内存锁定。要为自定义批次或数据类型启用内存锁定,请在您的自定义类型上定义一个 pin_memory() 方法。

请参阅下面的示例。

示例

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[源码]#

数据加载器将数据集和采样器组合在一起,并提供给定数据集上的可迭代对象。

DataLoader 支持映射式(map-style)和可迭代式(iterable-style)数据集,支持单进程或多进程加载、自定义加载顺序,以及可选的自动批处理(整理)和内存固定(memory pinning)。

更多详情请参见 torch.utils.data 文档页面。

参数:
  • dataset (Dataset) – 用于加载数据的数据集。

  • batch_size (int, 可选) – 每个批次要加载的样本数量(默认:1)。

  • shuffle (bool, 可选) – 设置为 True 可在每个 epoch 重新打乱数据(默认:False)。

  • sampler (SamplerIterable, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了 __len__Iterable。如果指定了此参数,则不能指定 shuffle

  • batch_sampler (SamplerIterable, 可选) – 类似于 sampler,但每次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, 可选) – 用于数据加载的子进程数量。0 表示数据将在主进程中加载。(默认:0

  • collate_fn (Callable, 可选) – 将样本列表合并以形成张量(Tensor)的小批次(mini-batch)。在使用映射式数据集进行批量加载时使用。

  • pin_memory (bool, 可选) – 如果为 True,数据加载器将在返回前将张量复制到设备/CUDA 的固定内存(pinned memory)中。如果您的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型的批次,请参考下面的示例。

  • drop_last (bool, 可选) – 如果数据集大小不能被批次大小整除,设置为 True 可丢弃最后一个不完整的批次。如果为 False 且数据集大小不能被整除,则最后一个批次会更小。(默认:False

  • timeout (数值型, 可选) – 如果为正数,表示从工作进程收集批次的超时值。必须始终为非负数。(默认:0

  • worker_init_fn (Callable, 可选) – 如果不为 None,则在每个工作子进程中(设置随机种子后、数据加载前)调用此函数,并传入工作进程 ID(一个范围在 [0, num_workers - 1] 之间的整数)。(默认:None

  • multiprocessing_context (strmultiprocessing.context.BaseContext, 可选) – 如果为 None,将使用操作系统的默认 多进程上下文。(默认:None

  • generator (torch.Generator, 可选) – 如果不为 None,此 RNG(随机数生成器)将用于 RandomSampler 生成随机索引,并用于多进程生成工作进程的 base_seed。(默认:None

  • prefetch_factor (int, 可选, 仅限关键字参数) – 每个工作进程提前加载的批次数量。2 表示所有工作进程总共预取 2 * num_workers 个批次。(默认值取决于 num_workers 的设置。如果 num_workers=0,则默认为 None。否则,若 num_workers > 0,默认为 2)。

  • persistent_workers (bool, 可选) – 如果为 True,数据加载器在数据集被消耗一次后不会关闭工作进程。这可以保持工作进程的 Dataset 实例处于活跃状态。(默认:False

  • pin_memory_device (str, 可选) – 已弃用,如果设置了 pin_memory=True,当前的 加速器 (accelerator) 将被用作设备。

  • in_order (bool, 可选) – 如果为 False,数据加载器将不再强制要求按先进先出 (FIFO) 的顺序返回批次。仅在 num_workers > 0 时有效。(默认:True

警告

如果使用了 spawn 启动方法,worker_init_fn 不能是不可序列化的对象(例如 lambda 函数)。有关 PyTorch 中多进程的更多详细信息,请参阅 多进程最佳实践

警告

len(dataloader) 的启发式计算基于所使用的采样器的长度。当 datasetIterableDataset 时,它会返回基于 len(dataset) / batch_size 的估算值,并根据 drop_last 进行适当的舍入,无论多进程加载配置如何。这是 PyTorch 能做出的最佳猜测,因为它信任用户提供的 dataset 代码能正确处理多进程加载以避免重复数据。

然而,如果分片(sharding)导致多个工作进程有不完整的最后一个批次,此估算值仍可能不准确,因为 (1) 一个原本完整的批次可能被拆分为多个,以及 (2) 当设置了 drop_last 时,可能会丢弃超过一个批次的样本。遗憾的是,PyTorch 通常无法检测此类情况。

有关这两类数据集的更多详细信息,以及 IterableDataset 如何与 多进程数据加载 交互,请参见 数据集类型 (Dataset Types)

警告

有关随机种子相关的问题,请参阅 可复现性 (Reproducibility)我的数据加载器工作进程返回相同的随机数 以及 多进程数据加载中的随机性 等注意事项。

警告

in_order 设置为 False 可能会损害可复现性,并在数据不平衡的情况下可能导致输入给训练器的数据分布出现偏差。

class torch.utils.data.Dataset[源码]#

表示 Dataset 的抽象类。

所有表示从键到数据样本的映射的数据集都应继承此类。所有子类都应重写 __getitem__(),以支持根据给定键获取数据样本。子类还可以选择性地重写 __len__(),许多 Sampler 实现和 DataLoader 的默认选项都期望该方法返回数据集的大小。子类还可以选择性地实现 __getitems__(),以加速批量样本加载。该方法接受一批样本的索引列表并返回对应的样本列表。

注意

DataLoader 默认构造一个生成整数索引的索引采样器。若要使其与具有非整数索引/键的映射式数据集一起工作,必须提供自定义采样器。

class torch.utils.data.IterableDataset[源码]#

一种可迭代的数据集。

所有表示数据样本可迭代对象的数据集都应继承此类。当数据来自数据流时,这种形式的数据集特别有用。

所有子类都应重写 __iter__(),该方法应返回此数据集中样本的迭代器。

当子类与 DataLoader 一起使用时,数据集中的每一项都将从 DataLoader 迭代器中产出。当 num_workers > 0 时,每个工作进程都会拥有数据集对象的副本,因此通常需要独立配置每个副本,以避免从工作进程返回重复数据。在工作进程中调用 get_worker_info() 可返回有关该工作进程的信息。它可以在数据集的 __iter__() 方法或 DataLoaderworker_init_fn 选项中使用,以修改每个副本的行为。

示例 1:在 __iter__() 中跨工作进程拆分工作负载

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

示例 2:使用 worker_init_fn 跨工作进程拆分工作负载

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[源码]#

包装张量的数据集。

每个样本将通过沿第一维度索引张量来获取。

参数:

*tensors (Tensor) – 第一维度大小相同的张量。

class torch.utils.data.StackDataset(*args, **kwargs)[源码]#

作为多个数据集堆叠的数据集。

该类用于组装作为数据集提供的复杂输入数据的不同部分。

示例

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}
参数:
  • *args (Dataset) – 以元组形式返回堆叠后的数据集。

  • **kwargs (Dataset) – 以字典形式返回堆叠后的数据集。

class torch.utils.data.ConcatDataset(datasets)[源码]#

作为多个数据集级联的数据集。

该类用于组装不同的现有数据集。

参数:

datasets (序列) – 待级联的数据集列表

class torch.utils.data.ChainDataset(datasets)[源码]#

用于链接多个 IterableDataset 的数据集。

该类对于组装不同的现有数据集流非常有用。链接操作是动态进行的,因此使用此类级联大规模数据集将非常高效。

参数:

datasets (可迭代对象, 包含 IterableDataset) – 待链接在一起的数据集

class torch.utils.data.Subset(dataset, indices)[源码]#

数据集在指定索引处的子集。

注意

继承 Subset 并重写 __getitem__ 时,您 必须 同时重写 __getitems__ 以确保 DataLoader 能正确处理您的自定义逻辑。如果您仅重写了 __getitem__,在使用 DataLoader 时将引发 NotImplementedError

__getitems__ 的简单实现可以委托给 __getitem__

def __getitems__(self, indices):
    return [self.__getitem__(idx) for idx in indices]

为了获得更好的性能,请考虑在 __getitems__ 中实现批处理感知逻辑,而不是多次调用 __getitem__

参数:
  • dataset (Dataset) – 完整的数据集

  • indices (序列) – 选定用于子集的完整集中的索引

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[源码]#

通用的整理函数,用于处理每个批次中集合类型的元素。

该函数还开放了一个函数注册表以处理特定元素类型。default_collate_fn_map 提供了针对张量、NumPy 数组、数字和字符串的默认整理函数。

参数:
  • batch – 待整理的单个批次

  • collate_fn_map (dict[type | tuple[type, ...], Callable] | None) – 从元素类型到相应整理函数的映射字典(可选)。如果元素类型未在此字典中,则该函数将按插入顺序遍历字典的每个键,并在元素类型为该键的子类时调用相应的整理函数。

示例

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

注意

每个整理函数都需要一个位置参数用于传入批次,以及一个关键字参数用于传入整理函数映射字典(即 collate_fn_map)。

torch.utils.data.default_collate(batch)[源码]#

接收一批数据,并将批次中的元素放入一个具有额外外部维度(批次大小)的张量中。

确切的输出类型可以是 torch.Tensortorch.TensorSequencetorch.Tensor 的 Collection,或者保持不变(取决于输入类型)。当在 DataLoader 中定义了 batch_sizebatch_sampler 时,此函数用作默认的整理函数。

以下是输入类型(基于批次内元素的类型)到输出类型的通用映射:

  • torch.Tensor -> torch.Tensor(添加了一个批次大小的外部维度)

  • NumPy 数组 -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str(不变)

  • bytes -> bytes(不变)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

参数:

batch – 待整理的单个批次

示例

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
>>> default_collate(batch)  # Handle `CustomType` automatically
torch.utils.data.default_convert(data)[源码]#

将每个 NumPy 数组元素转换为 torch.Tensor

如果输入是 SequenceCollectionMapping,它会尝试将其中的每个元素转换为 torch.Tensor。如果输入不是 NumPy 数组,则保持不变。当 DataLoader 中未定义 batch_samplerbatch_size 时,此函数用作默认的整理函数。

其输入类型到输出类型的通用映射与 default_collate() 类似。有关更多详情,请参见该函数的描述。

参数:

data – 待转换的单个数据点

示例

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[源码]#

返回有关当前 DataLoader 迭代器工作进程的信息。

在工作进程中调用时,返回的对象保证具有以下属性:

  • id:当前工作进程 ID。

  • num_workers:工作进程的总数。

  • seed:为当前工作进程设置的随机种子。该值由主进程的随机数生成器(RNG)和工作进程 ID 决定。更多详细信息,请参阅 DataLoader 的文档。

  • dataset当前进程中数据集对象的副本。请注意,在不同进程中,这将是与主进程中不同的对象。

当在主进程中调用时,此方法返回 None

注意

当在传递给 DataLoaderworker_init_fn 中使用时,该方法可用于对每个工作进程进行不同的设置。例如,利用 worker_id 配置 dataset 对象以仅读取分片数据集的特定部分,或者使用 seed 为数据集代码中使用的其他库设置种子。

返回类型:

WorkerInfo | None

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]#

将数据集随机拆分为给定长度的、不重叠的新数据集。

如果给出一组总和为 1 的比例(fractions),则长度将自动计算为 floor(frac * len(dataset))。

计算长度后,如果有余数,将以轮询(round-robin)的方式分配 1 个计数,直到没有余数为止。

可以选择固定生成器以获得可重复的结果,例如:

示例

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
参数:
  • dataset (Dataset) – 要拆分的数据集

  • lengths (sequence) – 要生成的拆分长度或比例

  • generator (Generator) – 用于随机排列的生成器。

返回类型:

list[Subset[_T]]

class torch.utils.data.Sampler[source]#

所有采样器(Samplers)的基类。

每个 Sampler 子类都必须提供一个 __iter__() 方法,提供一种遍历数据集元素索引或索引列表(批次)的方式,并且可以提供一个返回迭代器长度的 __len__() 方法。

示例

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

注意

__len__() 方法并非 DataLoader 严格要求,但任何涉及 DataLoader 长度的计算都会用到它。

class torch.utils.data.SequentialSampler(data_source)[source]#

按顺序对元素进行采样,始终保持相同的顺序。

参数:

data_source (Sized) – 要从中采样的数据源。必须实现 __len__。

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]#

随机采样元素。如果不进行有放回采样(replacement=False),则从洗牌后的数据集中进行采样。

如果进行有放回采样,用户可以指定要抽取的 num_samples

参数:
  • data_source (Sized) – 要从中采样的数据源。必须实现 __len__。

  • replacement (bool) – 若为 True,则按需进行有放回采样,默认值为 False

  • num_samples (int) – 要抽取的样本数,默认值为 len(dataset)

  • generator (Generator) – 用于采样的生成器。

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]#

从给定的索引列表中随机采样元素,无放回。

参数:
  • indices (sequence) – 索引序列

  • generator (Generator) – 用于采样的生成器。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]#

按给定的概率(权重)从 [0,..,len(weights)-1] 中采样元素。

参数:
  • weights (sequence) – 权重序列,不必加总为 1。

  • num_samples (int) – 要抽取的样本数。

  • replacement (bool) – 若为 True,则有放回采样。若为 False,则无放回采样,这意味着如果一个样本索引被某行选中,它在该行中将不能再次被选中。

  • generator (Generator) – 用于采样的生成器。

示例

>>> list(
...     WeightedRandomSampler(
...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
...     )
... )
[4, 4, 1, 4, 5]
>>> list(
...     WeightedRandomSampler(
...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
...     )
... )
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]#

包装另一个采样器以生成索引的微批次(mini-batch)。

参数:
  • sampler (Sampler or Iterable) – 基础采样器。可以是任何可迭代对象。

  • batch_size (int) – 微批次的大小。

  • drop_last (bool) – 若为 True,则当最后一个批次的大小小于 batch_size 时,采样器将丢弃它。

示例

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]#

将数据加载限制为数据集子集的采样器。

它特别适用于与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程可以将一个 DistributedSampler 实例作为 DataLoader 的采样器,并加载原始数据集专属于它的子集。

注意

假设数据集大小是恒定的,并且其实例始终以相同的顺序返回相同的元素。

参数:
  • dataset (Dataset) – 用于采样的数据集。

  • num_replicas (int, optional) – 参与分布式训练的进程数。默认情况下,world_size 从当前分布式组中检索。

  • rank (int, optional) – 当前进程在 num_replicas 中的排名。默认情况下,rank 从当前分布式组中检索。

  • shuffle (bool, optional) – 若为 True(默认值),采样器将打乱索引。

  • seed (int, optional) – 如果 shuffle=True,则用于打乱采样器的随机种子。此数字在分布式组的所有进程中应相同。默认值:0

  • drop_last (bool, optional) – 若为 True,则采样器将丢弃数据尾部,以使数据在各副本之间均匀分配。若为 False,则采样器将添加额外的索引,以使数据在各副本之间均匀分配。默认值:False

警告

在分布式模式下,为了使打乱操作在多个 epoch 之间正常工作,必须在创建 DataLoader 迭代器之前在每个 epoch 开始时调用 set_epoch() 方法。否则,将始终使用相同的顺序。

示例

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)