DistributedDataParallel#
- class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, init_sync=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision=None, device_mesh=None, skip_all_reduce_unused_params=False, bucket_cap_mb_list=None)[源码]#
在模块级别实现基于
torch.distributed的分布式数据并行。该容器通过在每个模型副本之间同步梯度来提供数据并行性。同步所涉及的设备由输入的
process_group指定,默认情况下为整个 world。请注意,DistributedDataParallel不会对参与的 GPU 之间的输入进行分块或分片;用户负责定义如何操作,例如通过使用DistributedSampler。另请参阅:基础知识 和 使用 nn.parallel.DistributedDataParallel 代替 multiprocessing 或 nn.DataParallel。适用于
torch.nn.DataParallel的相同输入约束也适用于此。创建此类要求
torch.distributed已通过调用torch.distributed.init_process_group()完成初始化。事实证明,对于单机多卡数据并行训练,
DistributedDataParallel明显快于torch.nn.DataParallel。要在具有 N 个 GPU 的主机上使用
DistributedDataParallel,您应该启动N个进程,确保每个进程专门处理 0 到 N-1 中的单个 GPU。这可以通过为每个进程设置CUDA_VISIBLE_DEVICES或通过为 GPU 调用以下 API 来实现:>>> torch.cuda.set_device(i)
或者调用针对 加速器 的统一 API:
>>> torch.accelerator.set_device_index(i)
其中 i 取值范围为 0 到 N-1。在每个进程中,您应该参考以下内容来构造此模块:
>>> if torch.accelerator.is_available(): >>> device_type = torch.accelerator.current_accelerator().type >>> vendor_backend = torch.distributed.get_default_backend_for_device(device_type) >>> >>> torch.distributed.init_process_group( >>> backend=vendor_backend, world_size=N, init_method='...' >>> ) >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
或者,您可以使用最新的 API 进行初始化:
>>> torch.distributed.init_process_group(device_id=i)
为了在每个节点上启动多个进程,您可以使用
torch.distributed.launch或torch.multiprocessing.spawn。注意
请参阅 PyTorch 分布式概述,了解与分布式训练相关的所有功能的简要介绍。
注意
DistributedDataParallel可以与torch.distributed.optim.ZeroRedundancyOptimizer结合使用,以减少每个 rank 的优化器状态内存占用。请参阅 ZeroRedundancyOptimizer 教程 了解更多详情。注意
使用 GPU 时,
nccl后端是目前最快且高度推荐的后端。这适用于单节点和多节点分布式训练。注意
该模块还支持混合精度分布式训练。这意味着您的模型可以拥有不同类型的参数(如
fp16和fp32的混合类型),这些混合类型参数的梯度规约(gradient reduction)将可以正常工作。注意
如果您在一个进程上使用
torch.save来保存模块的检查点,并在其他进程上使用torch.load来恢复它,请确保为每个进程正确配置了map_location。如果没有map_location,torch.load会将模块恢复到保存该模块时所在的设备。注意
当模型在
M个节点上以batch=N进行训练时,如果损失是在 batch 中的实例间求和(而不是像往常那样求平均),那么与在单个节点上以batch=M*N训练的相同模型相比,梯度将小M倍(因为不同节点之间的梯度被取了平均)。当您希望获得与本地训练对应的数学等效训练过程时,应考虑到这一点。但在大多数情况下,您可以直接将 DistributedDataParallel 包装的模型、DataParallel 包装的模型和单 GPU 上的普通模型视为相同(例如,为等效的 batch size 使用相同的学习率)。注意
参数永远不会在进程之间广播。该模块对梯度执行 all-reduce 步骤,并假设所有进程中的优化器将以相同的方式修改它们。缓冲区(例如 BatchNorm 统计信息)在每次迭代中都会从 rank 0 进程中的模块广播到系统中的所有其他副本。
注意
如果您将 DistributedDataParallel 与 分布式 RPC 框架 结合使用,则应始终使用
torch.distributed.autograd.backward()来计算梯度,并使用torch.distributed.optim.DistributedOptimizer来优化参数。示例
>>> import torch.distributed.autograd as dist_autograd >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> import torch >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> import torch.distributed.rpc as rpc >>> from torch.distributed.rpc import RRef >>> >>> t1 = torch.rand((3, 3), requires_grad=True) >>> t2 = torch.rand((3, 3), requires_grad=True) >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) >>> ddp_model = DDP(my_model) >>> >>> # Setup optimizer >>> optimizer_params = [rref] >>> for param in ddp_model.parameters(): >>> optimizer_params.append(RRef(param)) >>> >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> optimizer_params, >>> lr=0.05, >>> ) >>> >>> with dist_autograd.context() as context_id: >>> pred = ddp_model(rref.to_here()) >>> loss = loss_func(pred, target) >>> dist_autograd.backward(context_id, [loss]) >>> dist_optim.step(context_id)
注意
DistributedDataParallel 目前对使用
torch.utils.checkpoint()的梯度检查点提供有限的支持。如果使用 use_reentrant=False(推荐)进行检查点操作,DDP 将按预期工作,没有任何限制。但是,如果使用 use_reentrant=True(默认)进行检查点操作,当模型中没有未使用参数且每一层最多被检查点化一次时,DDP 将按预期工作(确保不向 DDP 传递 find_unused_parameters=True)。我们目前不支持某一层被多次检查点化,或检查点模型中存在未使用参数的情况。注意
要让非 DDP 模型加载 DDP 模型的 state dict,需要在加载前应用
consume_prefix_in_state_dict_if_present()来去除 DDP state dict 中的前缀 “module.”。警告
构造函数、forward 方法和输出(或该模块输出的函数)的微分都是分布式同步点。考虑到不同进程可能正在执行不同的代码,请注意这一点。
警告
该模块假设所有参数在创建时都已在模型中注册。之后不应添加或删除任何参数。缓冲区也是如此。
警告
该模块假设每个分布式进程的模型中注册的所有参数顺序相同。模块本身将按照模型注册参数的相反顺序进行梯度
allreduce。换句话说,用户有责任确保每个分布式进程具有完全相同的模型,从而具有完全相同的参数注册顺序。警告
该模块允许具有非行优先连续步长(non-rowmajor-contiguous strides)的参数。例如,您的模型可能包含一些
torch.memory_format为torch.contiguous_format的参数,以及其他格式为torch.channels_last的参数。但是,不同进程中对应的参数必须具有相同的步长。警告
该模块不适用于
torch.autograd.grad()(即,它仅在梯度累积在参数的.grad属性中时才有效)。警告
如果您计划将此模块与
nccl后端或gloo后端(使用 Infiniband)以及使用多个 worker 的 DataLoader 一起使用,请将 multiprocessing 的启动方法更改为forkserver(仅限 Python 3)或spawn。遗憾的是,Gloo(使用 Infiniband)和 NCCL2 不是 fork 安全的,如果您不更改此设置,很可能会遇到死锁。警告
在用
DistributedDataParallel包装模型后,绝不应尝试更改模型的参数。因为在包装模型时,DistributedDataParallel的构造函数会在构造时在模型本身的所有参数上注册额外的梯度规约函数。如果您随后更改模型的参数,梯度规约函数将不再匹配正确的参数集。警告
将
DistributedDataParallel与 分布式 RPC 框架 结合使用是实验性的,并且可能会发生变化。- 参数:
module (Module) – 要并行化的模块
device_ids (list of int or torch.device) –
CUDA 设备。1) 对于单设备模块,
device_ids可以恰好包含一个设备 ID,表示该进程对应的输入模块所在的唯一 CUDA 设备。或者,device_ids也可以是None。2) 对于多设备模块和 CPU 模块,device_ids必须为None。当这两种情况的
device_ids均为None时,前向传播的输入数据和实际模块都必须放置在正确的设备上。(默认:None)output_device (int or torch.device) – 单设备 CUDA 模块输出的设备位置。对于多设备模块和 CPU 模块,它必须为
None,由模块本身决定输出位置。(默认:对于单设备模块为device_ids[0])broadcast_buffers (bool) – 在
forward函数开始时启用同步(广播)模块缓冲区的标志。(默认:True)init_sync (bool) – 是否在初始化期间进行同步以验证参数形状并广播参数和缓冲区。警告:如果将其设置为 False,用户需要自行确保所有 rank 上的权重相同。(默认:
True)process_group – 用于分布式数据 all-reduction 的进程组。如果为
None,将使用由torch.distributed.init_process_group()创建的默认进程组。(默认:None)bucket_cap_mb –
DistributedDataParallel将参数分桶到多个桶中,以便每个桶的梯度规约可以潜在地与反向计算重叠。bucket_cap_mb以 MebiBytes (MiB) 为单位控制桶大小。如果为None,将使用默认大小 25 MiB。(默认:None)find_unused_parameters (bool) – 从被包装模块的
forward函数返回值中包含的所有张量开始遍历 autograd 图。在该图中不接收梯度的参数会被预先标记为准备好进行规约(reduce)。此外,可能在被包装模块的forward函数中使用过但不是损失计算的一部分、因此也不会接收梯度的参数也会被预先标记为准备好规约。(默认:False)check_reduction – 此参数已弃用。
gradient_as_bucket_view (bool) – 当设置为
True时,梯度将是映射到allreduce通信桶不同偏移量的视图。这可以降低峰值内存使用量,节省的内存大小等于总梯度大小。此外,它避免了在梯度和allreduce通信桶之间进行复制的开销。当梯度为视图时,不能在梯度上调用detach_()。如果遇到此类错误,请参考torch/optim/optimizer.py中的zero_grad()函数作为解决方案。请注意,梯度在第一次迭代后将变为视图,因此应在第一次迭代后检查峰值内存节省情况。static_graph (bool) –
当设置为
True时,DDP 知道训练图是静态的。静态图意味着:1) 在整个训练循环期间,已使用和未使用的参数集不会改变;在这种情况下,用户是否设置find_unused_parameters = True并不重要。2) 图的训练方式在整个训练循环期间不会改变(意味着没有依赖于迭代次数的控制流)。当 static_graph 设置为True时,DDP 将支持过去无法支持的情况:1) 可重入反向传播(Reentrant backwards)。2) 多次激活检查点(Activation checkpointing)。3) 模型具有未使用参数时的激活检查点。4) 存在 forward 函数之外的模型参数。5) 存在未使用参数时可能会提高性能,因为当 static_graph 设置为True时,DDP 不会在每次迭代中搜索图来检测未使用参数。要检查是否可以将 static_graph 设置为True,一种方法是检查上一次模型训练结束时的 ddp 日志数据,如果ddp_logging_data.get("can_set_static_graph") == True,通常您也可以设置static_graph = True。- 示例:
>>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) >>> # Training loop >>> ... >>> ddp_logging_data = model_DDP._get_ddp_logging_data() >>> static_graph = ddp_logging_data.get("can_set_static_graph")
delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – 命名参数列表,当
param_to_hook_all_reduce中指定的参数梯度就绪时,这些命名参数的 all reduce 将被延迟。DDP 的其他参数不适用于此参数中指定的命名参数,因为这些命名参数将被 DDP 规约器(reducer)忽略。param_to_hook_all_reduce (torch.nn.Parameter) – 用于挂钩
delay_all_reduce_named_params中指定参数延迟 all reduce 的参数。skip_all_reduce_unused_params – 当设置为 True 时,DDP 将跳过规约未使用参数。这要求在整个训练过程中,所有 rank 上的未使用参数保持一致。如果不满足此条件,可能会导致不同步并导致训练挂起。
- 变量:
module (Module) – 要并行化的模块。
示例
>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') >>> net = torch.nn.parallel.DistributedDataParallel(model)
- join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)[源码]#
用于在 DDP 中跨进程处理不均匀输入进行训练的上下文管理器。
此上下文管理器将跟踪已加入(joined)的 DDP 进程,并通过插入集合通信操作来“模拟”前向和反向传递,以匹配由未加入的 DDP 进程创建的操作。这将确保每个集合调用都有已加入 DDP 进程对应的调用,从而防止在跨进程训练不均匀输入时可能发生的挂起或错误。或者,如果将标志
throw_on_early_termination指定为True,则一旦某个 rank 耗尽输入,所有训练器都将抛出错误,从而允许根据应用程序逻辑捕获和处理这些错误。一旦所有 DDP 进程都已加入,上下文管理器将把最后加入进程对应的模型广播到所有进程,以确保所有进程的模型相同(这由 DDP 保证)。
要使用此功能实现跨进程不均匀输入训练,只需将此上下文管理器包装在您的训练循环周围即可。无需对模型或数据加载进行进一步修改。
警告
如果此上下文管理器包装的模型或训练循环具有额外的分布式集合操作,例如模型前向传递中的
SyncBatchNorm,则必须启用标志throw_on_early_termination。这是因为此上下文管理器无法感知非 DDP 集合通信。此标志将导致在任何一个 rank 耗尽输入时所有 rank 都会抛出异常,从而允许跨所有 rank 捕获并恢复这些错误。- 参数:
divide_by_initial_world_size (bool) – 如果为
True,将把梯度除以 DDP 训练启动时的初始world_size。如果为False,将计算有效 world size(尚未耗尽输入的 rank 数量),并在 allreduce 期间将梯度除以该值。设置divide_by_initial_world_size=True可确保每个输入样本(包括不均匀输入)在对全局梯度的贡献方面具有相同的权重。即使遇到不均匀输入,我们也始终将梯度除以初始world_size,从而实现这一点。如果将其设置为False,我们将梯度除以剩余的节点数。这确保了与在较小world_size上训练的对等性,尽管这也意味着不均匀输入对全局梯度的贡献更大。通常,对于训练作业的最后几个输入不均匀的情况,您可能希望将其设置为True。在输入数量差异很大的极端情况下,将其设置为False可能会提供更好的结果。enable (bool) – 是否启用不均匀输入检测。在已知参与进程之间输入均匀的情况下,传入
enable=False以禁用此功能。默认值为True。throw_on_early_termination (bool) – 当至少一个 rank 耗尽输入时,是抛出错误还是继续训练。如果为
True,将在第一个 rank 到达数据末尾时抛出异常。如果为False,将以较小的有效 world size 继续训练,直到所有 rank 都加入。请注意,如果指定了此标志,则会忽略divide_by_initial_world_size标志。默认值为False。
示例
>>> import torch >>> import torch.distributed as dist >>> import os >>> import torch.multiprocessing as mp >>> import torch.nn as nn >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> torch.cuda.set_device(rank) >>> model = nn.Linear(1, 1, bias=False).to(rank) >>> model = torch.nn.parallel.DistributedDataParallel( >>> model, device_ids=[rank], output_device=rank >>> ) >>> # Rank 1 gets one more input than rank 0. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] >>> with model.join(): >>> for _ in range(5): >>> for inp in inputs: >>> loss = model(inp).sum() >>> loss.backward() >>> # Without the join() API, the below synchronization will hang >>> # blocking for rank 1's allreduce to complete. >>> torch.cuda.synchronize(device=rank)
- join_hook(**kwargs)[源码]#
DDP join hook 通过在前向和反向传递中镜像通信来实现不均匀输入上的训练。
- 参数:
kwargs (dict) – 一个
dict,包含用于在运行时修改 join hook 行为的任何关键字参数;共享同一 join 上下文管理器的所有Joinable实例都将转发相同的kwargs值。
- 该钩子支持以下关键字参数:
- divide_by_initial_world_size (bool, 可选)
如果为
True,则梯度除以 DDP 启动时的初始 world size。如果为False,则梯度除以有效 world size(即未加入进程的数量),这意味着不均匀输入对全局梯度的贡献更大。通常,如果不均匀程度较小,应将其设置为True,但在极端情况下可以将其设置为False以获得可能更好的结果。默认值为True。
- no_sync()[源码]#
用于禁用 DDP 进程间梯度同步的上下文管理器。
在此上下文中,梯度将累积在模块变量上,这些变量随后将在退出该上下文后的第一个前向-反向传递中进行同步。
示例
>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> with ddp.no_sync(): >>> for input in inputs: >>> ddp(input).backward() # no synchronization, accumulate grads >>> ddp(another_input).backward() # synchronize grads
警告
前向传递应包含在上下文管理器中,否则梯度仍将被同步。
- register_comm_hook(state, hook)[源码]#
注册通信钩子,用于用户定义的跨多个 worker 的 DDP 梯度聚合。
该钩子对于研究人员尝试新想法非常有用。例如,该钩子可用于实现 GossipGrad 和梯度压缩等几种算法,这些算法在运行分布式 DataParallel 训练时涉及不同的参数同步通信策略。
- 参数:
state (object) –
传递给钩子以在训练过程中维护任何状态信息。示例包括梯度压缩中的错误反馈、GossipGrad 中下一步要通信的对等体等。
它由每个 worker 本地存储,并由该 worker 上的所有梯度张量共享。
hook (Callable) –
具有以下签名的可调用对象:
hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]一旦桶(bucket)准备就绪,就会调用此函数。钩子可以执行所需的任何处理,并返回一个 Future,表示任何异步工作(例如:allreduce)的完成。如果钩子不执行任何通信,它仍必须返回一个已完成的 Future。Future 应持有梯度桶张量的新值。一旦桶准备就绪,c10d 规约器将调用此钩子并使用 Future 返回的张量,并将梯度复制到各个参数。请注意,Future 的返回类型必须是单个张量。
我们还提供了一个名为
get_future的 API,用于检索与c10d.ProcessGroup.Work完成相关的 Future。get_future目前支持 NCCL,也支持 GLOO 和 MPI 上的大多数操作,但点对点操作(send/recv)除外。
警告
梯度桶的张量不会预先除以 world_size。用户有责任在执行 allreduce 等操作时除以 world_size。
警告
DDP 通信钩子只能注册一次,并且应在调用 backward 之前注册。
警告
钩子返回的 Future 对象应包含一个与梯度桶内的张量具有相同形状的单个张量。
警告
get_futureAPI 支持 NCCL,并部分支持 GLOO 和 MPI 后端(不支持点对点操作,如 send/recv),并将返回一个torch.futures.Future。- 示例:
下面是一个返回相同张量的空操作(noop)钩子示例。
>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> fut = torch.futures.Future() >>> fut.set_result(bucket.buffer()) >>> return fut >>> ddp.register_comm_hook(state=None, hook=noop)
- 示例:
下面是一个并行 SGD 算法示例,其中梯度在 allreduce 之前编码,然后在 allreduce 之后解码。
>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> # Define the then callback to decode. >>> def decode(fut): >>> decoded_tensor = decode(fut.value()[0]) # decode gradients >>> return decoded_tensor >>> return fut.then(decode) >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)