评价此页

DataParallel#

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#

在模块级别实现数据并行。

这个容器通过在批次维度上分块来将输入分割到指定的设备上,从而实现给定 module 的并行应用(其他对象将每个设备复制一次)。在前向传播中,模块会在每个设备上复制,每个副本处理一部分输入。在后向传播过程中,每个副本的梯度将被累加到原始模块中。

批次大小应大于使用的 GPU 数量。

警告

建议使用 DistributedDataParallel 进行多 GPU 训练,而不是此类,即使只有一个节点。请参见: 使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel分布式数据并行

允许将任意位置参数和关键字参数传递给 DataParallel,但某些类型会得到特殊处理。张量将在指定的维度(默认为 0)上被**散布**。元组、列表和字典类型将被浅拷贝。其他类型将在不同线程之间共享,如果它们在模型的正向传播中被写入,可能会损坏。

并行化的 module 在运行此 DataParallel 模块之前,必须将其参数和缓冲区放在 device_ids[0] 上。

警告

在每次正向传播中,module 会在每个设备上**复制**,因此在 forward 中对运行模块的任何更新都将丢失。例如,如果 module 有一个在每次 forward 中递增的计数器属性,它将始终保持初始值,因为更新是在副本上完成的,而这些副本在 forward 之后就会被销毁。然而,DataParallel 保证 device[0] 上的副本的参数和缓冲区将与基础并行化的 module 共享存储。因此,对 device[0] 上的参数或缓冲区的**原地**更新将被记录。例如,BatchNorm2dspectral_norm() 依赖于此行为来更新缓冲区。

警告

module 及其子模块上定义的正向和反向钩子将被调用 len(device_ids) 次,每次使用位于特定设备上的输入。特别地,仅保证钩子相对于对应设备上的操作以正确的顺序执行。例如,不能保证通过 register_forward_pre_hook() 设置的钩子会在 所有 len(device_ids)forward() 调用之前执行,但保证每个这样的钩子会在对应设备上对应的 forward() 调用之前执行。

警告

moduleforward() 中返回一个标量(即 0 维张量)时,此包装器将返回一个长度等于数据并行使用的设备数量的向量,其中包含每个设备的结果。

注意

在使用 Module(包装在 DataParallel 中)的 pack sequence -> recurrent network -> unpack sequence 模式时存在一个细微之处。详情请参阅 FAQ 中的 我的循环网络无法与数据并行配合使用 部分。

参数
  • module (Module) – 要并行化的模块

  • device_ids (list of int or torch.device) – CUDA 设备(默认:所有设备)

  • output_device (int or torch.device) – 输出的设备位置(默认:device_ids[0])

变量

module (Module) – 要并行化的模块

示例

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU