评价此页

分布式优化器#

创建于: 2021年3月1日 | 最后更新于: 2025年6月16日

警告

目前在使用 CUDA 张量时不支持分布式优化器

torch.distributed.optim 暴露了 DistributedOptimizer,它接收一个远程参数列表(RRef)并在参数所在的 worker 上本地运行优化器。分布式优化器可以使用任何本地优化器(请参阅 基类)在每个 worker 上应用梯度。

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[源代码]#

DistributedOptimizer 接收分布在 worker 上的参数的远程引用,并为每个参数在本地应用给定的优化器。

此类使用 get_gradients() 来检索特定参数的梯度。

step() 的并发调用(无论是来自同一客户端还是不同客户端)将在每个 worker 上被序列化——因为每个 worker 的优化器一次只能处理一组梯度。然而,不能保证完整的正向-反向-优化器序列为单个客户端一次执行。这意味着正在应用的梯度可能不对应于在给定 worker 上执行的最新正向传递。此外,跨 worker 也没有保证的顺序。

DistributedOptimizer 默认启用 TorchScript 来创建本地优化器,这样优化器更新就不会在多线程训练(例如分布式模型并行)的情况下被 Python 全局解释器锁(GIL)阻塞。此功能目前对大多数优化器都已启用。您也可以按照 PyTorch 教程中的 示例 为您自己的自定义优化器启用 TorchScript 支持。

参数:
  • optimizer_class (optim.Optimizer) – 在每个 worker 上实例化的优化器类。

  • params_rref (list[RRef]) – 要优化的本地或远程参数的 RRef 列表。

  • args – 要传递给每个 worker 上的优化器构造函数的参数。

  • kwargs – 要传递给每个 worker 上的优化器构造函数的参数。

示例:
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)
step(context_id)[源代码]#

执行一次优化步骤。

这将调用包含要优化的参数的每个 worker 上的 torch.optim.Optimizer.step(),并阻塞直到所有 worker 返回。提供的 context_id 将用于检索包含应应用于参数的梯度的相应 context

参数:

context_id – 我们应该运行优化器步骤的 autograd 上下文 ID。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[源代码]#

包装任意 torch.optim.Optimizer 并运行 post-local SGD。此优化器在每一步都运行本地优化器。在预热阶段之后,它会在应用本地优化器后定期平均参数。

参数:
  • optim (Optimizer) – 本地优化器。

  • averager (ModelAverager) – 用于运行 post-localSGD 算法的模型平均器实例。

示例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState,
>>>   post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim,
>>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()
load_state_dict(state_dict)[源代码]#

这与 torch.optim.Optimizerload_state_dict() 相同,但还会在提供的 state_dict 中恢复模型平均器的步数,以使其恢复到保存时的值。

如果在 state_dict 中没有 "step" 条目,则会发出警告并将模型平均器的步数初始化为 0。

state_dict()[源代码]#

这与 torch.optim.Optimizerstate_dict() 相同,但会添加一个额外条目来将模型平均器的步数记录到检查点,以确保重新加载时不会再次出现不必要的预热。

step()[源代码]#

执行一次优化步骤(参数更新)。

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[源代码]#

包装任意 optim.Optimizer 并将状态分片到组中的 ranks。

分片按照 ZeRO 中所述的方式进行。

每个 rank 上的本地优化器实例仅负责更新大约 1 / world_size 个参数,因此只需要保留 1 / world_size 个优化器状态。在本地更新参数后,每个 rank 会将其参数广播到所有其他对等节点,以使所有模型副本保持相同的状态。ZeroRedundancyOptimizer 可以与 torch.nn.parallel.DistributedDataParallel 结合使用,以减少每个 rank 的峰值内存消耗。

ZeroRedundancyOptimizer 使用排序贪婪算法在每个 rank 上打包一定数量的参数。每个参数属于一个 rank,并且不跨 rank 分割。分区是任意的,可能与参数注册或使用顺序不匹配。

参数:

params (Iterable) – 一个 Iterable,包含所有参数的 torch.Tensordict,它们将被分片到 ranks。

关键字参数:
  • optimizer_class (torch.nn.Optimizer) – 本地优化器的类。

  • process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (默认:由 torch.distributed.init_process_group() 初始化 的 dist.group.WORLD)。

  • parameters_as_bucket_view (bool, optional) – 如果为 True,则参数被打包到桶中以加快通信,并且 param.data 字段指向不同偏移量的桶视图;如果为 False,则单独通信每个参数,并且每个 params.data 保持不变(默认:False)。

  • overlap_with_ddp (bool, optional) – 如果为 True,则 step()DistributedDataParallel 的梯度同步重叠;这需要(1)为 optimizer_class 参数提供一个函数式优化器,或者提供一个具有函数式等价物的优化器,并且(2)注册一个从 ddp_zero_hook.py 中的函数之一构建的 DDP 通信钩子;参数被打包到与 DistributedDataParallel 中的桶匹配的桶中,这意味着 parameters_as_bucket_view 参数将被忽略。如果为 False,则 step() 在反向传播之后独立运行(按正常情况)。(默认:False

  • **defaults – 任何尾部参数,它们将按原样传递给本地优化器。

示例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

当前,ZeroRedundancyOptimizer 要求所有传入的参数都是相同的密集类型。

警告

如果您传递 overlap_with_ddp=True,请注意以下事项:根据目前实现 DistributedDataParallelZeroRedundancyOptimizer 重叠的方式,前两个或三个训练迭代不会在优化器步骤中执行参数更新,具体取决于 static_graph=Falsestatic_graph=True,分别是。这是因为需要有关 DistributedDataParallel 使用的梯度分桶策略的信息,该策略在 static_graph=False 的情况下直到第二次正向传递才会最终确定,而在 static_graph=True 的情况下直到第三次正向传递才会最终确定。为了解决这个问题,一个选项是预置虚拟输入。

警告

ZeroRedundancyOptimizer 处于实验阶段,可能会发生变化。

add_param_group(param_group)[源代码]#

将一个参数组添加到 Optimizerparam_groups 中。

这在微调预训练网络时可能很有用,因为在训练过程中可以使冻结层可训练并添加到 Optimizer 中。

参数:

param_group (dict) – 指定要优化的参数以及特定于组的优化选项。

警告

此方法负责更新所有分片上的分片,但需要在所有 ranks 上调用。在 ranks 的子集上调用此方法将导致训练挂起,因为通信原语是根据管理参数调用的,并且期望所有 ranks 都参与同一组参数。

consolidate_state_dict(to=0)[源代码]#

在目标 rank 上合并一组(每个 rank 一个)state_dict

参数:

to (int) – 接收优化器状态的 rank(默认:0)。

抛出:

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 实例完全初始化(一旦 DistributedDataParallel 梯度桶重建完成)之前调用此方法。

警告

这需要在所有 ranks 上调用。

property join_device: device#

返回默认设备。

join_hook(**kwargs)[源代码]#

返回 ZeRO join 钩子。

它通过在优化器步骤中影射集体通信来支持在不均匀输入上进行训练。

调用此钩子之前必须正确设置梯度。

参数:

kwargs (dict) – 一个 dict,其中包含任何关键字参数,用于在运行时修改 join 钩子的行为;共享同一 join 上下文管理器的所有 Joinable 实例都会收到 kwargs 的相同值。

此钩子不支持任何关键字参数;即 kwargs 未使用。

property join_process_group: Any#

返回进程组。

load_state_dict(state_dict)[源代码]#

从输入的 state_dict 加载属于给定 rank 的状态,并根据需要更新本地优化器。

参数:

state_dict (dict) – 优化器状态;应为调用 state_dict() 返回的对象。

抛出:

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 实例完全初始化(一旦 DistributedDataParallel 梯度桶重建完成)之前调用此方法。

state_dict()[源代码]#

返回此 rank 已知的最后一个全局优化器状态。

抛出:

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 实例完全初始化(一旦 DistributedDataParallel 梯度桶重建完成)之前调用此方法;或者如果在调用 consolidate_state_dict() 之前调用此方法。

返回类型:

dict[str, Any]

step(closure=None, **kwargs)[源代码]#

执行一次优化步骤并同步所有 ranks 上的参数。

参数:

closure (Callable) – 一个重新评估模型并返回损失的闭包;对大多数优化器来说是可选的。

返回:

根据底层本地优化器,可选的损失。

返回类型:

float | None

注意

任何额外的参数都将按原样传递给基础优化器。