评价此页

DDP 通信钩子#

创建于:2025 年 6 月 6 日 | 最后更新于:2025 年 6 月 6 日

DDP 通信钩子是一个通用的接口,用于通过覆盖 DistributedDataParallel 中的标准 allreduce 来控制跨工作节点(workers)通信梯度的方式。提供了几个内置的通信钩子,用户可以轻松应用这些钩子来优化通信。此外,钩子接口还可以支持用户自定义的通信策略,以满足更高级的用例。

如何使用通信钩子?#

要使用通信钩子,用户只需在训练循环开始之前,让 DDP 模型注册该钩子,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通信钩子操作什么?#

通信钩子提供了一种灵活的方式来 allreduce 梯度。因此,它主要在 allreduce 之前对每个副本上的梯度进行操作,这些梯度会被分桶(bucketized)以增加通信和计算之间的重叠。特别地,torch.distributed.GradBucket 代表了一个待 allreduce 的梯度张量集合。

class torch.distributed.GradBucket#

此类主要将一个扁平化的梯度张量(由 buffer() 返回)传递给 DDP 通信钩子。该张量可以进一步分解为该分桶内按参数划分的张量列表(由 get_per_parameter_tensors() 返回),以便应用层级操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int#

警告

由于分桶在第一次迭代后会重建,因此在训练开始时,不应依赖于索引。

返回

存储了几个连续层梯度的分桶的索引。所有梯度都被分桶。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor#
返回

一个扁平化的 1D torch.Tensor 缓冲区,可以进一步分解为该分桶内按参数划分的张量列表。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]#
返回

一个 torch.Tensor 列表。列表中的每个张量对应一个梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool#
返回

此分桶是否是迭代中最后一个要 allreduce 的分桶。这也意味着此分桶对应于前向传播中的前几个层。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None#

用输入的张量缓冲区替换分桶中的张量。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]#
返回

一个 torch.Tensor 列表。列表中的每个张量对应一个模型参数。

默认通信钩子#

默认通信钩子是简单的**无状态**钩子,因此 register_comm_hook 中的 state 参数要么是进程组(process group),要么是 None。输入的 bucket 是一个 torch.distributed.GradBucket 对象。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]#

使用 GradBucket 张量调用 allreduce

一旦梯度张量在所有工作节点上聚合完毕,其 then 回调函数会计算平均值并返回结果。

如果用户注册了这个 DDP 通信钩子,DDP 的结果预期与未注册钩子时相同。因此,这不会改变 DDP 的行为,用户可以将其用作参考,或者修改此钩子以记录有用的信息或用于其他目的,同时不影响 DDP 的行为。

示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
返回类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]#

通过将 GradBucket 转换为 torch.float16 并除以进程组大小来进行压缩。

这个 DDP 通信钩子实现了一种简单的梯度压缩方法,它将 GradBucket 张量转换为半精度浮点格式 (torch.float16),然后除以进程组大小。它对这些 float16 梯度张量进行 allreduce。一旦压缩的梯度张量 allreduce 完成,链式回调函数 decompress 会将其转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
返回类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]#

警告:此 API 尚处于实验阶段,需要 NCCL 版本大于 2.9.6。

这个 DDP 通信钩子实现了一种简单的梯度压缩方法,它将 GradBucket 张量转换为半精度 Brain 浮点格式 (torch.bfloat16),然后除以进程组大小。它对这些 bfloat16 梯度张量进行 allreduce。一旦压缩的梯度张量 allreduce 完成,链式回调函数 decompress 会将其转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
返回类型

Future[Tensor]

此外,还提供了一个通信钩子包装器,用于支持 fp16_compress_hook()bf16_compress_hook() 作为包装器,可以与其他通信钩子结合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]#

将输入张量转换为 torch.float16,将钩子的结果转换回输入数据类型。

此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度浮点格式 (torch.float16),并将给定钩子的结果张量转换回输入数据类型,例如 float32。因此,fp16_compress_hook 等同于 fp16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
返回类型

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]#

警告:此 API 尚处于实验阶段,需要 NCCL 版本大于 2.9.6。

此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度 Brain 浮点格式 (torch.bfloat16),并将给定钩子的结果张量转换回输入数据类型,例如 float32

因此,bf16_compress_hook 等同于 bf16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
返回类型

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通信钩子#

PowerSGD(Vogels 等人,NeurIPS 2019)是一种梯度压缩算法,可以提供非常高的压缩率并加速带宽受限的分布式训练。该算法需要同时维护一些超参数和内部状态。因此,PowerSGD 通信钩子是一个**有状态**的钩子,用户需要提供一个如下定义的 state 对象。

PowerSGD State#

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]#

存储算法的超参数和所有梯度在训练期间的内部状态。

特别地,matrix_approximation_rankstart_powerSGD_iter 是用户应该调整的主要超参数。为了提高性能,建议保持二元超参数 use_error_feedbackwarm_start 为 True。

  1. matrix_approximation_rank 控制压缩低秩张量的大小,这决定了压缩率。秩越低,压缩越强。

    1.1. 如果 matrix_approximation_rank 太低,完整的模型质量需要更多训练步骤才能达到,或者永远无法达到并导致准确率损失。

    1.2. 增加 matrix_approximation_rank 会显著增加压缩的计算成本,并且准确率可能在超过某个 matrix_approximation_rank 阈值后不再进一步提高。

为了调整 matrix_approximation_rank,我们建议从 1 开始,并以 2 的倍数增加(类似指数网格搜索,1, 2, 4, …),直到达到满意的准确率。通常只使用较小的值 1-4。对于某些 NLP 任务(如原始论文附录 D 所示),此值已增加到 32。

  1. start_powerSGD_iter 将 PowerSGD 压缩推迟到第 start_powerSGD_iter 步,并在第 start_powerSGD_iter 步之前运行标准 allreduce。这种**标准 allreduce + PowerSGD** 的混合方案可以有效提高准确率,即使使用相对较小的 matrix_approximation_rank。这是因为训练的初始阶段通常对不精确的梯度非常敏感,过早压缩梯度可能会使训练很快进入次优轨迹,从而对准确率产生不可逆转的影响。

为了调整 start_powerSGD_iter,我们建议从总训练步数的 10% 开始,并将其增加直到达到满意的准确率。如果训练中有预热阶段,start_powerSGD_iter 通常不应小于预热步数。

  1. min_compression_rate 是压缩层所需的最低压缩率。由于压缩会带来计算开销,只有当带宽节省足够大时,张量才值得压缩,其中 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果指定的压缩率阈值无法满足,该张量将被直接 allreduce 而不进行压缩。

一旦 PowerSGD 压缩开始,每隔 compression_stats_logging_frequency 次迭代都会记录压缩统计信息。

  1. orthogonalization_epsilon 可以是一个非常小的值(例如 1e-8),添加到正交化步骤中的每个归一化矩阵列中,以防止在任何列全为 0 时发生除零错误。如果这个问题已经被防止(例如通过批量归一化),则建议将 epsilon 设置为 0 以保证准确性。

  2. batch_tensors_with_same_shape 控制是否将相同形状的张量进行批处理操作以压缩和解压缩,以实现更高的并行度。请注意,您还应该增加分桶大小(即 DDP 构造函数中的 bucket_cap_mb 参数),以便在同一个分桶中出现更多相同形状的张量,但这可能会降低计算和通信之间的重叠,并因堆叠相同形状的张量而增加内存占用。如果压缩/解压缩计算是瓶颈,则将其设置为 True

警告

如果启用了错误反馈或预热,DDP 中允许的 start_powerSGD_iter 的最小值是 2。这是因为 DDP 中还有一个内部优化会在迭代 1 时重建分桶,这可能会与在重建过程之前记住的任何张量发生冲突。

PowerSGD 钩子#

警告

PowerSGD 通常需要与模型梯度相同大小的额外内存来支持错误反馈,这可以补偿有偏的压缩通信并提高准确性。

警告

PowerSGD 钩子可能与 Apex 自动混合精度包冲突。请改用 PyTorch 的原生自动混合精度包

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]#

实现 PowerSGD 算法。

这个 DDP 通信钩子实现了 论文 中描述的 PowerSGD 梯度压缩算法。一旦梯度张量在所有工作节点上聚合完毕,该钩子按如下方式应用压缩:

  1. 将输入的扁平化一维梯度张量视为按参数划分的张量列表,并将所有张量分为两组:

    1.1. 在 allreduce 之前应被压缩的张量,因为压缩可以在带宽节省方面提供足够的收益。

    1.2. 其余的张量将被直接 allreduce 而不压缩,包括所有向量张量(用于偏置)。

  2. 处理未压缩的张量

    2.1. 为这些未压缩的张量分配连续内存,并将所有未压缩的张量作为一批进行 allreduce,不进行压缩;

    2.2. 将单个未压缩的张量从连续内存复制回输入张量。

  3. 处理应被 PowerSGD 压缩的张量

    3.1. 对于每个张量 M,创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

    3.2. 计算 Ps 中的每个 P,它等于 MQ;

    3.3. 将 Ps 作为一批进行 allreduce;

    3.4. 正交化 Ps 中的每个 P;

    3.5. 计算 Qs 中的每个 Q,它约等于 M^TP;

    3.6. 将 Qs 作为一批进行 allreduce;

    3.7. 计算压缩张量中的每个 M,它约等于 PQ^T。

请注意,此通信钩子在前 state.start_powerSGD_iter 次迭代中强制执行标准 allreduce。这不仅使用户能够更好地控制速度提升和准确性之间的权衡,还有助于为未来的通信钩子开发者抽象化 DDP 的内部优化。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持错误反馈、预热等的 state 信息。要调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket (dist.GradBucket) – 存储批处理多个按变量张量的 1D 扁平化梯度张量的分桶。请注意,由于 DDP comm hook 只支持单进程单设备模式,因此此分桶中只存储一个张量。

返回

通信的 Future 处理程序,它会就地更新梯度。

返回类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]#

实现简化的 PowerSGD 算法。

这个 DDP 通信钩子实现了 论文 中描述的简化的 PowerSGD 梯度压缩算法。此变体不是逐层压缩梯度,而是压缩批处理所有梯度的扁平化输入张量。因此,它比 powerSGD_hook() **更快**,但通常结果是**准确率低得多**,除非 matrix_approximation_rank 为 1。

警告

在这里增加 matrix_approximation_rank 可能不一定会提高准确率,因为在没有行/列对齐的情况下批处理按参数张量可能会破坏低秩结构。因此,用户应始终首先考虑 powerSGD_hook(),只有当 matrix_approximation_rank 为 1 时能达到令人满意的准确率时,才考虑此变体。

一旦梯度张量在所有工作节点上聚合完毕,该钩子按如下方式应用压缩:

  1. 将输入的扁平化一维梯度张量视为一个带有 0 填充的方形张量 M;

  2. 创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

  3. 计算 P,它等于 MQ;

  4. allreduce P;

  5. 正交化 P;

  6. 计算 Q,它约等于 M^TP;

  7. allreduce Q;

  8. 计算 M,它约等于 PQ^T。

  9. 将输入张量截断到原始长度。

请注意,此通信钩子在前 state.start_powerSGD_iter 次迭代中强制执行标准 allreduce。这不仅使用户能够更好地控制速度提升和准确性之间的权衡,还有助于为未来的通信钩子开发者抽象化 DDP 的内部优化。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持错误反馈、预热等的 state 信息。要调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_iter

  • bucket (dist.GradBucket) – 存储批处理多个按变量张量的 1D 扁平化梯度张量的分桶。请注意,由于 DDP comm hook 只支持单进程单设备模式,因此此分桶中只存储一个张量。

返回

通信的 Future 处理程序,它会就地更新梯度。

返回类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

调试通信钩子#

顾名思义,调试通信钩子**仅**用于调试和性能优化目的。

警告

调试通信钩子不一定输出正确的结果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]#

返回一个包装输入张量的 Future,因此它是一个无操作(no-op),不会产生任何通信开销。

此钩子**仅**用于 allreduce 优化的余量分析,而不是正常的梯度同步。例如,如果注册此钩子后训练时间仅观察到不到 10% 的加速,这通常意味着 allreduce 在此情况下不是性能瓶颈。这种仪器化在 GPU 轨迹难以检索或轨迹分析因 allreduce 与计算的重叠或跨进程的失步等因素而变得复杂时尤其有用。

示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
返回类型

Future[Tensor]

通信钩子的检查点#

有状态的通信钩子可以作为模型检查点的一部分进行保存,以实现训练器的重新启动。要使钩子可序列化,应定义 __setstate____getstate__

警告

__getstate__ 应从返回的字典中排除不可序列化属性。

警告

__setstate__ 应正确初始化从提供的 state 中排除的不可序列化属性。

PowerSGDState 已实现 __setstate____getstate__,可作为参考。

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
__getstate__()[source]#

返回一个将被序列化并保存的 Dict[str, Any]

process_group 不可序列化,因此从返回的状态中排除。

__setstate__(state)[source]#

采用提供的 state 并将其设置为此 PowerSGDState 实例。

process_group 被设置为默认值。

下面是一个保存和重新加载 PowerSGD state 和 hook 的简单端到端示例。


import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致谢#

非常感谢 PowerSGD 论文作者 **Thijs Vogels** 对 PowerSGD 通信钩子的代码评审,以及 比较实验,这些实验表明 PowerSGD 通信钩子的性能与原始 论文 中的实现相当。