
torch.utils.data#

Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for:

  • map-style and iterable-style datasets,

  • customizing data loading order,

  • automatic batching,

  • single- and multi-process data loading,

  • automatic memory pinning.

These options are configured by the constructor arguments of a DataLoader, which has signature:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

The sections below describe in detail the effects and usages of these options.

Dataset Types#

The most important argument of DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

  • map-style datasets,

  • iterable-style datasets.

Map-style datasets#

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.
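
As a minimal sketch (the folder layout, the labels.csv format, and the torchvision read_image helper are assumptions for illustration, not part of this API), such a dataset could look like:

import csv
import os

from torch.utils.data import Dataset
from torchvision.io import read_image  # assumed helper returning a C x H x W uint8 tensor

class ImageFolderDataset(Dataset):
    def __init__(self, root):
        self.root = root
        # assumed format: each row of labels.csv is "filename,label"
        with open(os.path.join(root, "labels.csv")) as f:
            self.samples = [(name, int(label)) for name, label in csv.reader(f)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        name, label = self.samples[idx]
        image = read_image(os.path.join(self.root, name))
        return image, label

Here dataset[idx] returns the idx-th (image, label) pair and len(dataset) returns the number of samples, which is exactly what DataLoader and the samplers below rely on.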

See Dataset for more details.

Iterable-style datasets#

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
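
A minimal sketch of such a dataset (the log file and its tab-separated format are assumptions for illustration):

from torch.utils.data import IterableDataset

class LogStreamDataset(IterableDataset):
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        # stream samples one by one; no random access is ever required
        with open(self.path) as f:
            for line in f:
                yield line.rstrip("\n").split("\t")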

See IterableDataset for more details.

Note

When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.

Data Loading Order and Sampler#

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).

The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.

A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
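
For example, the two ways of plugging a sampler into a DataLoader look roughly like this (a sketch using the built-in RandomSampler and BatchSampler documented further below; dataset stands for any map-style dataset):

from torch.utils.data import BatchSampler, DataLoader, RandomSampler

# sampler yields one index at a time; the DataLoader groups them into batches itself
loader = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))

# batch_sampler yields whole lists of indices; batch_size, shuffle, sampler and
# drop_last must then be left at their defaults
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)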

Note

Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.

Loading Batched and Non-Batched Data#

DataLoader supports automatically collating individual fetched data samples into batches via the arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).

Automatic batching (default)#

This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).

When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.

Note

The batch_size and drop_last arguments essentially are used to construct a batch_sampler from sampler. For map-style datasets, the sampler is either provided by user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See this section for more details on samplers.

Note

When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker's dataset replica.

After fetching a list of samples using the indices from sampler, the function passed as the collate_fn argument is used to collate lists of samples into batches.

In this case, loading from a map-style dataset is roughly equivalent with:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent with:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length of a batch. See this section for more about collate_fn.
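
As a sketch of such a custom collate_fn, assuming each sample is a 1-D tensor of varying length:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch is a list of 1-D tensors of different lengths; pad to the longest
    # one and also return the original lengths
    lengths = torch.tensor([t.size(0) for t in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

# loader = DataLoader(dataset, batch_size=8, collate_fn=pad_collate)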

Disable automatic batching#

In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it's likely better to not use automatic batching (where collate_fn is used to collate the samples), but let the data loader directly return each member of the dataset object.

When both batch_size and batch_sampler are None (default value for batch_sampler is already None), automatic batching is disabled. Each sample obtained from the dataset is processed with the function passed as the collate_fn argument.

When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent with:

for data in iter(dataset):
    yield collate_fn(data)

See this section for more about collate_fn.
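
As a usage sketch, disabling automatic batching for a dataset whose __getitem__ already returns a ready-made batch (bulk_read_dataset is hypothetical):

from torch.utils.data import DataLoader

# batch_size=None disables automatic batching; each element the dataset yields
# is passed through default_convert (NumPy arrays become Tensors) and returned as-is
loader = DataLoader(bulk_read_dataset, batch_size=None, num_workers=2)
for batch in loader:
    ...  # batch is exactly one element of bulk_read_dataset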

Working with collate_fn#

The usage of collate_fn is slightly different when automatic batching is enabled or disabled.

When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.

When automatic batching is enabled, collate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default collate_fn (default_collate()).

For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single tuple of a batched image Tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties:

  • It always prepends a new dimension as the batch dimension.

  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.

  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for lists, tuples, namedtuples, etc.

Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.

If you run into a situation where the outputs of DataLoader have dimensions or types that are different from your expectation, you may want to check your collate_fn.

Single- and Multi-process Data Loading#

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

Single-process data loading (default)#

In this mode, data fetching is done in the same process a DataLoader is initialized in. Therefore, data loading may block computing. However, this mode may be preferred when resources used for sharing data among processes (e.g., shared memory, file descriptors) are limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging.

Multi-process data loading#

Setting the argument num_workers as a positive integer will turn on multi-process data loading with the specified number of loader worker processes.
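
For example (a sketch; train_dataset stands for any dataset defined elsewhere):

from torch.utils.data import DataLoader

# four worker processes fetch and collate batches while the main process trains
loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                    num_workers=4, pin_memory=True)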

Warning

After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects. Check out issue #13246 for more details on why this occurs and example code for how to work around these problems.

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize and fetch data. This means that dataset access together with its internal IO and transforms (including collate_fn) runs in the worker processes.

torch.utils.data.get_worker_info() returns various useful information in a worker process (including the worker id, dataset replica, initial seed, etc.) and returns None in the main process. Users may use this function in dataset code and/or worker_init_fn to individually configure each dataset replica, and to determine whether the code is running in a worker process. For example, this can be particularly helpful in sharding the dataset.

For map-style datasets, the main process generates the indices using sampler and sends them to the workers. So any shuffle randomization is done in the main process, which guides loading by assigning indices to load.

For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See the IterableDataset documentation for how to achieve this.) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker's iterable-style dataset replica.

Workers are shut down once the end of the iteration is reached, or when the iterator becomes garbage collected.

Warning

It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.

Platform-specific behaviors#

Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix.

  • On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.

  • On Windows or MacOS, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading:

  • Wrap most of your main script's code within an if __name__ == '__main__': block, to make sure it doesn't run again (most likely generating an error) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn't need to be re-executed in workers.

  • Make sure that any custom collate_fn, worker_init_fn or dataset code is declared as a top level definition, outside of the __main__ check. This ensures that they are available in worker processes. (This is needed since functions are pickled as references only, not as bytecode.) A minimal script layout following both steps is sketched below.
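
A minimal script layout that follows both steps might look like this (a sketch; MyDataset and my_collate are placeholders, not real APIs):

import torch
from torch.utils.data import DataLoader, Dataset

# top-level definitions: picklable by reference from worker processes
class MyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor([idx], dtype=torch.float32)

def my_collate(batch):
    return torch.stack(batch, 0)

if __name__ == "__main__":
    # everything that must run only once stays under the __main__ guard
    loader = DataLoader(MyDataset(), batch_size=10, num_workers=2,
                        collate_fn=my_collate)
    for batch in loader:
        pass  # training step would go here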

Randomness in multi-process data loading#

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby, consuming a RNG state mandatorily) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
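
For example, a worker_init_fn that reseeds NumPy and the standard library random module from the per-worker PyTorch seed could look like this (a sketch):

import random

import numpy as np
import torch

def seed_worker(worker_id):
    # torch.initial_seed() already equals base_seed + worker_id inside a worker;
    # fold it into 32 bits for NumPy and reuse it for random
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)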

Memory Pinning#

Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.

For data loading, passing pin_memory=True to a DataLoader will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs.

The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a pin_memory() method on your custom type(s).

See the example below.

Example

import torch
from torch.utils.data import DataLoader, TensorDataset

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source]#

Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

See the torch.utils.data documentation page for more details.

Parameters
  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)

  • generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).

  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)

  • pin_memory_device (str, optional) – Deprecated, the current accelerator will be used as the device if pin_memory=True.

  • in_order (bool, optional) – If False, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when num_workers > 0. (default: True)

Warning

If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.

Warning

len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of multi-process loading configurations. This represents the best guess PyTorch can make because PyTorch trusts user dataset code in correctly handling multi-process loading to avoid duplicate data.

However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch can not detect such cases in general.
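
A plain-Python sketch of this length heuristic for the IterableDataset case (not the actual DataLoader implementation):

import math

def estimated_len(dataset_len, batch_size, drop_last):
    if drop_last:
        return dataset_len // batch_size        # trailing partial batch dropped
    return math.ceil(dataset_len / batch_size)  # trailing partial batch kept

print(estimated_len(10, 3, drop_last=False))  # 4
print(estimated_len(10, 3, drop_last=True))   # 3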

See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.

Warning

Setting in_order to False can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.

class torch.utils.data.Dataset[source]#

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__(), for speeding up batched samples loading. This method accepts a list of indices of samples of the batch and returns a list of samples.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
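
A sketch of a map-style subclass that also implements the optional __getitems__() hook (the data itself is illustrative only):

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    def __init__(self, n):
        self.values = torch.arange(n) ** 2

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return self.values[idx]

    def __getitems__(self, indices):
        # fetch a whole batch of samples in one call
        return [self.values[i] for i in indices]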

class torch.utils.data.IterableDataset[source]#

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.

All subclasses should overwrite __iter__(), which would return an iterator of samples in this dataset.

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.

Example 1: splitting workload across all workers in __iter__()

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

class torch.utils.data.TensorDataset(*tensors)[source]#

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.

class torch.utils.data.StackDataset(*args, **kwargs)[source]#

Dataset as a stacking of multiple datasets.

This class is useful to assemble different parts of complex input data, given as datasets.

Example

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}

Parameters
  • *args (Dataset) – Datasets for stacking returned as tuple.

  • **kwargs (Dataset) – Datasets for stacking returned as dict.

class torch.utils.data.ConcatDataset(datasets)[source]#

Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Parameters

datasets (sequence) – List of datasets to be concatenated

class torch.utils.data.ChainDataset(datasets)[source]#

Dataset for chaining multiple IterableDatasets.

This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.

Parameters

datasets (iterable of IterableDataset) – datasets to be chained together

class torch.utils.data.Subset(dataset, indices)[source]#

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source]#

General collate function that handles collection type of element within each batch.

The function also opens function registry to deal with specific element types. default_collate_fn_map provides default collate functions for tensors, numpy arrays, numbers and strings.

Parameters
  • batch – a single batch to be collated

  • collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – Optional dictionary mapping from element type to the corresponding collate function. If the element type isn't present in this dictionary, this function will go through each key of the dictionary in the insertion order to invoke the corresponding collate function if the element type is a subclass of the key.

Example

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

Note

Each collate function requires a positional argument for batch and a keyword argument for the dictionary of collate functions as collate_fn_map.

torch.utils.data.default_collate(batch)[source]#

Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.

Here is the general input type (based on the type of the element within the batch) to output type mapping:

  • torch.Tensor -> torch.Tensor (with an added outer dimension batch size)

  • NumPy Arrays -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (unchanged)

  • bytes -> bytes (unchanged)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

Parameters

batch – a single batch to be collated

Example

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
>>> default_collate(batch)  # Handle `CustomType` automatically

torch.utils.data.default_convert(data)[source]#

Convert each NumPy array element into a torch.Tensor.

If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not a NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.

The general input type to output type mapping is similar to that of default_collate(). See the description there for more details.

Parameters

data – a single data point to be converted

Example

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]

torch.utils.data.get_worker_info()[source]#

Returns the information about the current DataLoader iterator worker process.

When called in a worker, this returns an object guaranteed to have the following attributes:

  • id: the current worker id.

  • num_workers: the total number of workers.

  • seed: the random seed set for the current worker. This value is determined by main process RNG and the worker id. See DataLoader's documentation for more details.

  • dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.

When called in the main process, this returns None.

Note

When used in a worker_init_fn passed over to DataLoader, this method can be useful to set up each worker process differently, for instance, using worker_id to configure the dataset object to only read a specific fraction of a sharded dataset, or use seed to seed other libraries used in dataset code.

Return type

Optional[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]#

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.

Optionally fix the generator for reproducible results, e.g.:

Example

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

Parameters
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths or fractions of splits to be produced

  • generator (Generator) – Generator used for the random permutation.

Return type

list[torch.utils.data.dataset.Subset[~_T]]

class torch.utils.data.Sampler(data_source=None)[source]#

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Parameters

data_source (Dataset) – This argument is not used and will be removed in 2.2.0. You may still have custom implementation that utilizes it.

Example

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

Note

The __len__() method isn't strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.

class torch.utils.data.SequentialSampler(data_source)[source]#

Samples elements sequentially, always in the same order.

Parameters

data_source (Dataset) – dataset to sample from

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]#

Samples elements randomly. If without replacement, then sample from a shuffled dataset.

If with replacement, then user can specify num_samples to draw.

Parameters
  • data_source (Dataset) – dataset to sample from

  • replacement (bool) – samples are drawn on-demand with replacement if True, default=False

  • num_samples (int) – number of samples to draw, default=len(dataset).

  • generator (Generator) – Generator used in sampling.

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]#

Samples elements randomly from a given list of indices, without replacement.

Parameters
  • indices (sequence) – a sequence of indices

  • generator (Generator) – Generator used in sampling.

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]#

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Parameters
  • weights (sequence) – a sequence of weights, not necessarily summing up to one

  • num_samples (int) – number of samples to draw

  • replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

  • generator (Generator) – Generator used in sampling.

Example

>>> list(
...     WeightedRandomSampler(
...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
...     )
... )
[4, 4, 1, 4, 5]
>>> list(
...     WeightedRandomSampler(
...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
...     )
... )
[0, 1, 4, 3, 2]

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]#

Wraps another sampler to yield a mini-batch of indices.

Parameters
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object

  • batch_size (int) – Size of mini-batch.

  • drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size.

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]#

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Note

The dataset is assumed to be of constant size, and any instance of it always returns the same elements in the same order.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), the sampler will shuffle the indices.

  • seed (int, optional) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

Warning

In distributed mode, calling the set_epoch() method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)