torch.utils.data#
创建于: 2025年6月13日 | 最后更新于: 2025年6月13日
PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它代表了一个数据集上的 Python 可迭代对象,支持
这些选项由 DataLoader 的构造函数参数配置,其签名如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
下面的部分将详细介绍这些选项的效果和用法。
数据集类型#
DataLoader 构造函数中最重要的参数是 dataset,它指定了要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:
基于映射的数据集#
基于映射的数据集是指实现了 __getitem__() 和 __len__() 协议的数据集,它代表一个从(可能的非整型)索引/键到数据样本的映射。
例如,这样的数据集在被访问时,如使用 dataset[idx],可以从磁盘上的文件夹中读取第 idx 个图像及其对应的标签。
更多详细信息请参见 Dataset。
基于迭代器的数据集#
基于迭代器的数据集是 IterableDataset 的一个子类实例,它实现了 __iter__() 协议,代表一个数据样本的可迭代对象。这类数据集特别适用于随机读取成本高昂甚至不可能实现,且批次大小依赖于获取到的数据的情况。
例如,当调用 iter(dataset) 时,这样的数据集可以返回一个从数据库、远程服务器甚至实时生成的日志中读取数据的流。
更多详细信息请参见 IterableDataset。
注意
在使用带有 多进程数据加载 的 IterableDataset 时。相同的 IterableDataset 对象会在每个工作进程中复制,因此必须对这些副本进行不同的配置以避免数据重复。请参阅 IterableDataset 文档了解如何实现这一点。
数据加载顺序和 Sampler#
对于 基于迭代器的数据集,数据加载顺序完全由用户定义的迭代器控制。这使得实现分块读取和动态批次大小(例如,每次产生一个批次的样本)变得更加容易。
本节的其余部分关注 基于映射的数据集 的情况。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的序列。它们代表数据集索引的可迭代对象。例如,在随机梯度下降 (SGD) 的常见情况下,Sampler 可以随机置换索引列表并逐个生成索引,或者为小批量 SGD 生成少量索引。
DataLoader 的 shuffle 参数会自动构造一个顺序或随机的采样器。或者,用户可以使用 sampler 参数指定一个自定义的 Sampler 对象,该对象每次生成下一个要获取的索引/键。
一个每次生成一批索引的自定义 Sampler 可以作为 batch_sampler 参数传入。也可以通过 batch_size 和 drop_last 参数启用自动批处理。有关此内容的更多详细信息,请参阅下一节。
注意
sampler 和 batch_sampler 都与基于迭代器的数据集不兼容,因为此类数据集没有索引或键的概念。
加载批处理和非批处理数据#
DataLoader 通过 batch_size、drop_last、batch_sampler 和 collate_fn(具有默认函数)参数支持自动将单独获取的数据样本整理成批次。
自动批处理(默认)#
这是最常见的情况,它对应于获取一个小批量数据并将它们整理成批次的样本,即包含一个维度是批次维度(通常是第一个)的 Tensor。
当 batch_size(默认值为 1)不为 None 时,数据加载器生成批次样本而不是单个样本。batch_size 和 drop_last 参数用于指定数据加载器如何获取数据集键的批次。对于基于映射的数据集,用户可以选择指定 batch_sampler,它一次生成一个键的列表。
注意
batch_size 和 drop_last 参数主要用于从 sampler 构造一个 batch_sampler。对于基于映射的数据集,sampler 要么由用户提供,要么根据 shuffle 参数构造。对于基于迭代器的数据集,sampler 是一个虚拟的无限采样器。有关采样器的更多详细信息,请参阅本节。
使用采样器从索引获取一批样本后,传递给 collate_fn 参数的函数用于将样本列表整理成批次。
在这种情况下,从基于映射的数据集加载大致相当于
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
而从基于迭代器的数据集加载大致相当于
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
可以使用自定义 collate_fn 来定制整理,例如将序列数据填充到批次的กล่าว最大长度。有关 collate_fn 的更多信息,请参阅本节。
禁用自动批处理#
在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者只是加载单个样本。例如,直接加载批处理数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批次大小依赖于数据,或者程序被设计为处理单个样本。在这些情况下,最好不要使用自动批处理(其中使用 collate_fn 来整理样本),而是让数据加载器直接返回 dataset 对象的每个成员。
当 batch_size 和 batch_sampler 都为 None 时(batch_sampler 的默认值已经是 None),自动批处理被禁用。从 dataset 中获取的每个样本都会使用传递给 collate_fn 参数的函数进行处理。
当禁用自动批处理时,默认的 collate_fn 简单地将 NumPy 数组转换为 PyTorch Tensor,并保持其他所有内容不变。
在这种情况下,从基于映射的数据集加载大致相当于
for index in sampler:
yield collate_fn(dataset[index])
而从基于迭代器的数据集加载大致相当于
for data in iter(dataset):
yield collate_fn(data)
有关 collate_fn 的更多信息,请参阅本节。
使用 collate_fn#
启用或禁用自动批处理时,collate_fn 的用法略有不同。
当禁用自动批处理时,collate_fn 会针对每个单独的数据样本调用,输出将从数据加载器迭代器中产生。在这种情况下,默认的 collate_fn 仅将 NumPy 数组转换为 PyTorch Tensor。
当启用自动批处理时,collate_fn 每次都会接收一个数据样本列表。它应该将输入的样本整理成一个批次,然后从数据加载器迭代器中产生。本节的其余部分描述了默认 collate_fn(default_collate())的行为。
例如,如果每个数据样本由一个 3 通道图像和一个整数类别标签组成,即数据集的每个元素返回一个元组 (image, class_index),则默认的 collate_fn 会将这样一个元组列表整理成一个包含一个批次图像 Tensor 和一个批次类别标签 Tensor 的单一元组。特别是,默认的 collate_fn 具有以下特性:
它总是在前面添加一个新维度作为批次维度。
它会自动将 NumPy 数组和 Python 数值转换为 PyTorch Tensor。
它保留数据结构,例如,如果每个样本是一个字典,它将输出一个具有相同键集的字典,但值为批次的 Tensor(如果值不能转换为 Tensor,则为列表)。
list、tuple、namedtuple等也是如此。
用户可以使用自定义的 collate_fn 来实现自定义批处理,例如,沿着第一个维度以外的维度进行整理、填充可变长度的序列,或为自定义数据类型添加支持。
如果您发现 DataLoader 的输出维度或类型与您的预期不同,您可能需要检查您的 collate_fn。
单进程和多进程数据加载#
DataLoader 默认使用单进程数据加载。
在 Python 进程内,全局解释器锁 (GIL) 阻止了在线程之间对 Python 代码进行真正的并行化。为避免数据加载阻塞计算代码,PyTorch 通过简单地将参数 num_workers 设置为正整数,提供了一个方便的开关来执行多进程数据加载。
单进程数据加载(默认)#
在此模式下,数据获取在初始化 DataLoader 的同一进程中完成。因此,数据加载可能会阻塞计算。然而,当用于在进程间共享数据(例如共享内存、文件描述符)的资源有限,或者整个数据集很小并且可以完全加载到内存中时,此模式可能更受欢迎。此外,单进程加载通常会显示出更易读的错误跟踪,因此对调试很有用。
多进程数据加载#
将参数 num_workers 设置为一个正整数将以指定的加载器工作进程数开启多进程数据加载。
警告
经过几次迭代后,加载器工作进程将消耗与父进程相同的 CPU 内存用于父进程中从工作进程访问的所有 Python 对象。如果 Dataset 包含大量数据(例如,您在 Dataset 构建时加载了非常大的文件名列表)和/或您使用了许多工作进程(总内存使用量为 工作进程数 * 父进程大小),这可能会很麻烦。最简单的解决方法是将 Python 对象替换为非引用计数的表示形式,如 Pandas、Numpy 或 PyArrow 对象。请查看 issue #13246 以了解有关此问题发生原因的更多详细信息以及解决这些问题的示例代码。
在此模式下,每次创建 DataLoader 的迭代器时(例如,当您调用 enumerate(dataloader) 时),将创建 num_workers 个工作进程。此时,dataset、collate_fn 和 worker_init_fn 会传递给每个工作进程,并在其中用于初始化和获取数据。这意味着数据集访问及其内部 I/O、转换(包括 collate_fn)都在工作进程中运行。
torch.utils.data.get_worker_info() 在工作进程中返回各种有用的信息(包括工作进程 ID、数据集副本、初始种子等),并在主进程中返回 None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来独立配置每个数据集副本,并确定代码是在工作进程中运行。例如,这在分片数据集时特别有用。
对于基于映射的数据集,主进程使用 sampler 生成索引并将其发送给工作进程。因此,任何随机洗牌操作都在主进程中完成,主进程通过分配要加载的索引来指导加载。
对于基于迭代器的数据集,由于每个工作进程都会获得 dataset 对象的一个副本,因此简单地进行多进程加载通常会导致数据重复。使用 torch.utils.data.get_worker_info() 和/或 worker_init_fn,用户可以独立配置每个副本。(有关如何实现这一点,请参阅 IterableDataset 文档。)出于类似的原因,在多进程加载中,drop_last 参数会丢弃每个工作进程基于迭代器的数据集副本的最后一个未满批次。
一旦迭代结束或迭代器被垃圾回收,工作进程就会被关闭。
警告
通常不建议在多进程加载中返回 CUDA Tensor,因为在多进程中使用 CUDA 和共享 CUDA Tensor 时存在许多微妙之处(请参阅多进程中的 CUDA)。相反,我们建议使用自动内存锁定(即设置 pin_memory=True),这可以实现快速的数据传输到启用了 CUDA 的 GPU。
特定于平台的行为#
由于工作进程依赖于 Python 的 multiprocessing,工作进程的启动行为在 Windows 与 Unix 上是不同的。
在 Unix 上,
fork()是默认的multiprocessing启动方法。使用fork()时,子工作进程通常可以通过克隆的地址空间直接访问dataset和 Python 参数函数。在 Windows 或 MacOS 上,
spawn()是默认的multiprocessing启动方法。使用spawn()时,会启动另一个解释器来运行主脚本,然后是内部工作函数,该函数通过pickle序列化接收dataset、collate_fn和其他参数。
这种单独的序列化意味着,为了确保在使用多进程数据加载时与 Windows 兼容,您需要采取两个步骤:
将主脚本代码的大部分内容包装在
if __name__ == '__main__':块中,以确保在启动每个工作进程时它不会再次运行(这很可能会导致错误)。您可以将数据集和DataLoader实例创建逻辑放在此处,因为它们不需要在工作进程中重新执行。确保任何自定义的
collate_fn、worker_init_fn或dataset代码被声明为顶级定义,位于__main__检查之外。这确保了它们在工作进程中可用。(这是必需的,因为函数仅作为引用而不是字节码被 pickling。)
多进程数据加载中的随机性#
默认情况下,每个工作进程的 PyTorch 种子将设置为 base_seed + worker_id,其中 base_seed 是主进程使用其 RNG 生成的长整数(从而强制消耗一个 RNG 状态)或指定的 generator。然而,其他库的种子在初始化工作进程时可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的本节。)
在 worker_init_fn 中,您可以通过 torch.utils.data.get_worker_info().seed 或 torch.initial_seed() 访问为每个工作进程设置的 PyTorch 种子,并使用它在数据加载前为其他库设置种子。
内存锁定#
当主机到 GPU 的复制操作源自锁定的(页面锁定)内存时,速度会快得多。有关何时以及如何使用锁定内存的一般信息,请参阅使用锁定的内存缓冲区。
为了加载数据,将 pin_memory=True 传递给 DataLoader 会自动将获取到的数据 Tensor 放入固化内存中,从而加速数据到支持 CUDA 的 GPU 的传输。
默认的内存固化逻辑仅识别 Tensor 以及包含 Tensor 的映射和可迭代对象。默认情况下,如果固化逻辑检测到一个批次是自定义类型(当你有一个返回自定义批次类型的 collate_fn 时就会发生这种情况),或者你的批次中的每个元素都是自定义类型,固化逻辑将无法识别它们,并且会原样返回该批次(或这些元素),而不进行内存固化。要为自定义批次或数据类型启用内存固化,请在你自定义类型的上定义一个 pin_memory() 方法。
请参阅以下示例。
示例
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source]#
DataLoader 组合了 dataset 和 sampler,并提供了一个可迭代的数据集。
DataLoader支持 map 风格和 iterable 风格的数据集,并支持单进程或多进程加载、自定义加载顺序以及可选的自动批处理(collation)和内存固化。有关更多详细信息,请参阅
torch.utils.data文档页面。- 参数:
dataset (Dataset) – 要从中加载数据的数据集。
batch_size (int, optional) – 每次加载的样本数(默认:
1)。shuffle (bool, optional) – 设置为
True以在每个 epoch 都重新洗牌数据(默认:False)。sampler (Sampler 或 Iterable, optional) – 定义从数据集中抽取样本的策略。可以是任何实现了
__len__的Iterable。如果指定了该参数,则必须不能指定shuffle。batch_sampler (Sampler 或 Iterable, optional) – 类似于
sampler,但一次返回一个批次的索引。与batch_size、shuffle、sampler和drop_last互斥。num_workers (int, optional) – 用于数据加载的子进程数量。
0表示数据将在主进程中加载。(默认:0)collate_fn (Callable, optional) – 将样本列表合并以形成一个 Tensor 批次。在从 map 风格数据集使用批处理加载时使用。
pin_memory (bool, optional) – 如果为
True,则数据加载器将在返回 Tensor 之前将其复制到设备/CUDA 固化内存中。如果你的数据元素是自定义类型,或者你的collate_fn返回一个自定义类型的批次,请参阅下面的示例。drop_last (bool, optional) – 如果数据集大小不能被 batch_size 整除,则设置为
True以丢弃最后一个不完整的批次。如果设置为False且数据集大小不能被 batch_size 整除,则最后一个批次会更小。(默认:False)timeout (numeric, optional) – 如果为正数,则表示从工作进程收集批次的超时值。应始终为非负数。(默认:
0)worker_init_fn (Callable, optional) – 如果不为
None,则将在每个工作进程子进程中调用,输入为工作进程 ID(一个介于[0, num_workers - 1]之间的整数),在设置随机种子之后,加载数据之前。(默认:None)multiprocessing_context (str 或 multiprocessing.context.BaseContext, optional) – 如果为
None,则将使用操作系统默认的 多进程上下文 # noqa: D401。(默认:None)generator (torch.Generator, optional) – 如果不为
None,则 RandomSampler 将使用此 RNG 来生成随机索引,并且多进程将使用此 RNG 来为工作进程生成base_seed。(默认:None)prefetch_factor (int, optional, keyword-only arg) – 每个工作进程预加载的批次数。
2表示所有工作进程总共将预加载 2 * num_workers 个批次。(默认值取决于 num_workers 的设置。如果 num_workers=0,则默认值为None。否则,如果num_workers > 0,则默认值为2)。persistent_workers (bool, optional) – 如果为
True,则数据加载器在数据集被消耗一次后不会关闭工作进程。这允许将工作进程的 Dataset 实例保持活动状态。(默认:False)pin_memory_device (str, optional) – 已弃用,如果
pin_memory=True,则将使用当前 加速器 作为设备。in_order (bool, optional) – 如果为
False,则数据加载器将不强制按先入先出顺序返回批次。仅在num_workers > 0时适用。(默认:True)
警告
如果使用了
spawn启动方法,则worker_init_fn不能是不可序列化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参阅 多进程最佳实践。警告
len(dataloader)的启发式方法基于所使用的 sampler 的长度。当dataset是IterableDataset时,它会根据len(dataset) / batch_size返回一个估算值,并根据drop_last进行适当的四舍五入,而忽略多进程加载配置。这是 PyTorch 能做的最佳猜测,因为 PyTorch 信任用户dataset代码能正确处理多进程加载以避免重复数据。然而,如果分片导致多个工作进程具有不完整的最后一个批次,此估算值可能仍然不准确,因为(1)一个原本完整的批次可以被分解成多个批次,以及(2)当设置
drop_last时,可能会丢弃一个批次以上的样本。不幸的是,PyTorch 通常无法检测到这种情况。有关这两种数据集类型以及
IterableDataset如何与 多进程数据加载 交互的更多信息,请参阅 数据集类型。警告
有关与随机种子相关的问题,请参阅 可复现性、我的数据加载器工作进程返回相同的随机数 和 多进程数据加载中的随机性 说明。
警告
将 in_order 设置为 False 可能会损害可复现性,并在数据不平衡的情况下导致训练器接收到有偏差的数据分布。
- class torch.utils.data.Dataset[source]#
代表
Dataset的抽象类。所有表示键到数据样本映射的数据集都应继承此类。所有子类都应重写
__getitem__(),以支持根据给定键获取数据样本。子类还可以选择性地重写__len__(),许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。子类还可以选择性地实现__getitems__(),以加速批处理样本加载。此方法接受样本批次的索引列表,并返回样本列表。注意
DataLoader默认构造一个生成整数索引的索引 sampler。要使其与具有非整数索引/键的 map 风格数据集一起工作,必须提供一个自定义 sampler。
- class torch.utils.data.IterableDataset[source]#
可迭代的数据集。
所有表示数据样本可迭代对象的数据集都应继承此类。这种形式的数据集在数据来自流时特别有用。
所有子类都应重写
__iter__(),它将返回此数据集中样本的迭代器。当子类与
DataLoader一起使用时,数据集中的每个项都将从DataLoader迭代器中生成。当num_workers > 0时,每个工作进程将拥有数据集对象的不同副本,因此通常希望独立配置每个副本以避免工作进程返回重复数据。get_worker_info()在工作进程中调用时,会返回有关工作进程的信息。它可以在数据集的__iter__()方法或DataLoader的worker_init_fn选项中使用,以修改每个副本的行为。示例 1:在
__iter__()中拆分所有工作进程的工作量>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Multi-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
示例 2:使用
worker_init_fn拆分所有工作进程的工作量>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
- class torch.utils.data.TensorDataset(*tensors)[source]#
包装 Tensor 的数据集。
每个样本将通过沿第一个维度索引 Tensor 来检索。
- 参数:
*tensors (Tensor) – 与第一个维度大小相同的 Tensor。
- class torch.utils.data.StackDataset(*args, **kwargs)[source]#
堆叠多个数据集的数据集。
此类有助于组合作为数据集的不同输入数据部分。
示例
>>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {"image": images[0], "text": texts[0]}
- class torch.utils.data.ConcatDataset(datasets)[source]#
多个数据集的连接数据集。
此类有助于组合不同的现有数据集。
- 参数:
datasets (sequence) – 要连接的数据集列表
- class torch.utils.data.ChainDataset(datasets)[source]#
连接多个
IterableDataset的数据集。此类有助于组合不同的现有数据集流。连接操作是即时进行的,因此使用此类连接大规模数据集将是高效的。
- 参数:
datasets (iterable of IterableDataset) – 要链接在一起的数据集
- class torch.utils.data.Subset(dataset, indices)[source]#
在指定索引处的数据集子集。
- 参数:
dataset (Dataset) – 整个数据集
indices (sequence) – 在整个集合中选择用于子集的索引
- torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source]#
处理批次内集合类型元素的通用 collate 函数。
该函数还开放了函数注册机制来处理特定的元素类型。default_collate_fn_map 为 tensor、numpy 数组、数字和字符串提供了默认的 collate 函数。
- 参数:
示例
>>> def collate_tensor_fn(batch, *, collate_fn_map): ... # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
注意
每个 collate 函数都需要一个用于批次的 positional 参数和一个用于 collate 函数字典的 keyword 参数,如 collate_fn_map。
- torch.utils.data.default_collate(batch)[source]#
接收一个数据批次,并将批次内的元素放入一个具有额外批次大小外层维度的 tensor 中。
确切的输出类型可以是
torch.Tensor、torch.Tensor的 Sequence、torch.Tensor的 Collection,或者根据输入类型保持不变。当DataLoader中定义了 batch_size 或 batch_sampler 时,它被用作默认的 collation 函数。以下是一般输入类型(基于批次内元素的类型)到输出类型的映射
torch.Tensor->torch.Tensor(添加了批次大小的外层维度)NumPy 数组 ->
torch.Tensorfloat ->
torch.Tensorint ->
torch.Tensorstr -> str (保持不变)
bytes -> bytes (保持不变)
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
- 参数:
batch – 要 collate 的单个批次
示例
>>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(["a", "b", "c"]) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple("Point", ["x", "y"]) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically
- torch.utils.data.default_convert(data)[source]#
将每个 NumPy 数组元素转换为
torch.Tensor。如果输入是 Sequence、Collection 或 Mapping,它会尝试将内部的每个元素转换为
torch.Tensor。如果输入不是 NumPy 数组,则保持不变。当DataLoader中既没有定义 batch_sampler 也没有定义 batch_size 时,它被用作默认的 collation 函数。一般输入类型到输出类型的映射与
default_collate()类似。有关更多详细信息,请参阅其描述。- 参数:
data – 要转换的单个数据点
示例
>>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple("Point", ["x", "y"]) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data.get_worker_info()[源代码]#
返回当前
DataLoader迭代器工作进程的信息。在工作进程中调用时,它将返回一个保证具有以下属性的对象:
id:当前工作进程 ID。num_workers:工作进程总数。seed:为当前工作进程设置的随机种子。此值由主进程的 RNG 和工作进程 ID 确定。有关更多详细信息,请参阅DataLoader的文档。dataset:**此**进程中数据集对象的副本。请注意,在不同进程中,它将是一个不同于主进程中对象的对象。
在主进程中调用时,将返回
None。注意
当用于传递给
DataLoader的worker_init_fn时,此方法可用于以不同方式设置每个工作进程,例如,使用worker_id配置dataset对象以仅读取分片数据集的特定部分,或使用seed来为数据集代码中使用的其他库设置种子。- 返回类型:
WorkerInfo | None
- torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[源代码]#
将数据集随机分割成给定长度的不重叠的新数据集。
如果给定一个总和为 1 的分数列表,则长度将自动计算为每个提供的分数的 floor(frac * len(dataset))。
计算长度后,如果有任何余数,将以轮询方式将 1 个计数分配给长度,直到没有余数为止。
可选地固定生成器以获得可复现的结果,例如:
示例
>>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- class torch.utils.data.Sampler[源代码]#
所有 Sampler 的基类。
每个 Sampler 子类都必须提供一个
__iter__()方法,提供一种迭代数据集元素索引或索引列表(批次)的方法,并且可以提供一个返回迭代器长度的__len__()方法。示例
>>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist()
注意
__len__()方法不是DataLoader严格必需的,但在涉及DataLoader长度的任何计算中都会期望它。
- class torch.utils.data.SequentialSampler(data_source)[源代码]#
按顺序采样元素,始终按相同顺序。
- 参数:
data_source (Sized) – 要从中采样的数据源。必须实现 __len__。
- class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[源代码]#
随机采样元素。如果无放回,则从打乱的数据集中采样。
如果带放回,则用户可以指定
num_samples来进行抽取。
- class torch.utils.data.SubsetRandomSampler(indices, generator=None)[源代码]#
从给定的索引列表中随机采样元素,无放回。
- 参数:
indices (sequence) – 索引序列
generator (Generator) – 用于采样的生成器。
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[源代码]#
根据给定的概率(权重)从
[0,..,len(weights)-1]中采样元素。- 参数:
示例
>>> list( ... WeightedRandomSampler( ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True ... ) ... ) [4, 4, 1, 4, 5] >>> list( ... WeightedRandomSampler( ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False ... ) ... ) [0, 1, 4, 3, 2]
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[源代码]#
封装另一个采样器以生成索引的迷你批次。
- 参数:
示例
>>> list( ... BatchSampler( ... SequentialSampler(range(10)), batch_size=3, drop_last=False ... ) ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list( ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True) ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[源代码]#
限制数据加载到数据集子集的采样器。
它尤其适用于与
torch.nn.parallel.DistributedDataParallel结合使用。在这种情况下,每个进程都可以将DistributedSampler实例作为DataLoader采样器传递,并加载原始数据集的一个独占子集。注意
假设数据集大小恒定,并且其任何实例始终以相同的顺序返回相同的元素。
- 参数:
dataset (Dataset) – 用于采样的数据集。
num_replicas (int, optional) – 参与分布式训练的进程数。默认情况下,
world_size从当前分布式组中检索。rank (int, optional) – 当前进程在
num_replicas中的排名。默认情况下,rank从当前分布式组中检索。shuffle (bool, optional) – 如果为
True(默认),则采样器将打乱索引。seed (int, optional) – 如果
shuffle=True,则用于打乱采样器的随机种子。此数字在分布式组中的所有进程中应相同。默认值:0。drop_last (bool, optional) – 如果为
True,则采样器将丢弃数据的尾部,使其能被副本数量整除。如果为False,则采样器将添加额外的索引,使数据能被副本数量整除。默认值:False。
警告
在分布式模式下,在每个 epoch 开始时,**在**创建
DataLoader迭代器**之前**调用set_epoch()方法对于正确实现多 epoch 打乱至关重要。否则,将始终使用相同的排序。示例
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)