DataParallel#
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#
在模块级别实现数据并行。
此容器通过在批处理维度(batch dimension)上对输入进行分块(其他对象将被复制到每个设备一次),从而并行化给定
module的应用。在前向传播中,模块在每个设备上进行复制,每个副本处理一部分输入。在反向传播期间,来自每个副本的梯度会被求和汇总到原始模块中。批处理大小(batch size)应该大于所使用的 GPU 数量。
警告
建议使用
DistributedDataParallel而不是此类进行多 GPU 训练,即使是在单节点上也是如此。请参阅:使用 nn.parallel.DistributedDataParallel 代替 multiprocessing 或 nn.DataParallel 以及 分布式数据并行。允许将任意位置参数和关键字参数输入传递给 DataParallel,但某些类型会得到特殊处理。张量(tensors)将会在指定的维度(默认为 0)上进行分散(scattered)。元组、列表和字典类型将会被浅拷贝。其他类型将在不同线程间共享,如果在模型的
forward传播中对其进行写入,则可能会导致损坏。在运行此
DataParallel模块之前,被并行化的module的参数和缓冲区必须位于device_ids[0]上。警告
在每次前向传播中,
module都会在每个设备上被复制(replicated),因此在forward中对运行中的模块所做的任何更新都将会丢失。例如,如果module有一个计数器属性并在每次forward中递增,它将始终保持在初始值,因为更新是在副本上进行的,而副本在forward后会被销毁。然而,DataParallel保证device[0]上的副本其参数和缓冲区与基础并行化module共享存储。因此,对device[0]上的参数或缓冲区的原地(in-place)更新将会被记录。例如,BatchNorm2d和spectral_norm()就是依靠这种行为来更新缓冲区。警告
定义在
module及其子模块上的前向和反向钩子(hook)将被调用len(device_ids)次,每次输入都位于特定设备上。特别地,仅保证钩子会根据对应设备上的操作顺序正确执行。例如,不保证通过register_forward_pre_hook()设置的钩子会在 所有len(device_ids)次forward()调用之前执行,但保证每个此类钩子会在该设备对应的forward()调用之前执行。警告
当
module在forward()中返回标量(即 0 维张量)时,此包装器将返回一个向量,其长度等于数据并行中使用的设备数量,其中包含来自每个设备的结果。注意
在
DataParallel包装的Module中使用pack sequence -> recurrent network -> unpack sequence模式时存在一个微妙之处。详情请参阅 FAQ 中的 我的循环神经网络无法在数据并行下工作 一节。- 参数:
module (Module) – 要并行化的模块
device_ids (list of int or torch.device) – CUDA 设备(默认:所有设备)
output_device (int or torch.device) – 输出的设备位置(默认:device_ids[0])
- 变量:
module (Module) – 要并行化的模块
示例
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU