DDP 通信钩子#
创建于:2025 年 6 月 6 日 | 最后更新于:2025 年 6 月 6 日
DDP 通信钩子是一个通用接口,用于通过覆盖 DistributedDataParallel 中的 vanilla allreduce 来控制如何在工作进程之间通信梯度。提供了几个内置通信钩子,用户可以轻松应用这些钩子来优化通信。此外,该钩子接口还可以支持用户定义的通信策略,以应对更高级的用例。
如何使用通信钩子?#
要使用通信钩子,用户只需在训练循环之前让 DDP 模型注册该钩子,如下所示。
torch.nn.parallel.DistributedDataParallel.register_comm_hook()
通信钩子操作什么?#
通信钩子提供了一种灵活的方式来对梯度进行 allreduce。因此,它主要在 allreduce 之前对每个副本上的梯度进行操作,这些梯度被分桶以增加通信和计算之间的重叠。特别是,torch.distributed.GradBucket
表示一个梯度张量桶,用于进行 allreduce。
- class torch.distributed.GradBucket#
这个类主要将一个扁平化的梯度张量(由
buffer()
返回)传递给 DDP 通信钩子。该张量可以进一步分解为该桶中的每个参数张量列表(由get_per_parameter_tensors()
返回)以应用层级操作。
- torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int #
警告
由于桶在第一次迭代后重建,因此不应依赖训练开始时的索引。
- 返回
存储一些连续层梯度的桶的索引。所有梯度都已分桶。
- torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor #
- 返回
一个扁平化的 1D
torch.Tensor
缓冲区,可以进一步分解为该桶中每个参数张量的列表。
- torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor] #
- 返回
一个
torch.Tensor
列表。列表中的每个张量都对应一个梯度。
- torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool #
- 返回
此桶是否是迭代中最后一个要进行 allreduce 的桶。这也意味着此桶对应于前向传递中的前几层。
- torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None #
用输入张量缓冲区替换 GradBucket 中的张量。
- torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor] #
- 返回
一个
torch.Tensor
列表。列表中的每个张量都对应一个模型参数。
默认通信挂钩#
默认通信挂钩是简单的 **无状态** 挂钩,因此 register_comm_hook
中的输入状态要么是进程组,要么是 None
。输入 bucket
是一个 torch.distributed.GradBucket
对象。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]#
使用
GradBucket
张量调用allreduce
。一旦梯度张量在所有 worker 之间聚合,它的
then
回调将取平均值并返回结果。如果用户注册此 DDP 通信挂钩,DDP 结果预计与未注册挂钩的情况相同。因此,这不会改变 DDP 的行为,用户可以将其作为参考或修改此挂钩以记录有用信息或任何其他目的,同时不影响 DDP 行为。
- 示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]#
通过将
GradBucket
转换为torch.float16
并除以进程组大小进行压缩。此 DDP 通信挂钩实现了一种简单的梯度压缩方法,该方法将
GradBucket
张量转换为半精度浮点格式 (torch.float16
),然后将其除以进程组大小。它对这些float16
梯度张量进行 allreduce。一旦压缩的梯度张量被 allreduce,链式回调decompress
将其转换回输入数据类型(例如float32
)。- 示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]#
警告:此 API 处于实验阶段,它需要 NCCL 版本晚于 2.9.6。
此 DDP 通信挂钩实现了一种简单的梯度压缩方法,该方法将
GradBucket
张量转换为半精度 Brain 浮点格式 (torch.bfloat16
),然后将其除以进程组大小。它对这些bfloat16
梯度张量进行 allreduce。一旦压缩的梯度张量被 allreduce,链式回调decompress
将其转换回输入数据类型(例如float32
)。- 示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
此外,还提供了一个通信挂钩包装器,以支持 fp16_compress_hook()
或 bf16_compress_hook()
作为包装器,可以与其他通信挂钩结合使用。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]#
将输入张量转换为
torch.float16
,将挂钩结果转换回输入数据类型。此包装器将给定 DDP 通信挂钩的输入梯度张量转换为半精度浮点格式 (
torch.float16
),并将给定挂钩的结果张量转换回输入数据类型,例如float32
。因此,fp16_compress_hook
等效于fp16_compress_wrapper(allreduce_hook)
。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
- 返回类型
Callable[[Any, GradBucket], Future[Tensor]]
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]#
警告:此 API 处于实验阶段,它需要 NCCL 版本晚于 2.9.6。
此包装器将给定 DDP 通信挂钩的输入梯度张量转换为半精度 Brain 浮点格式 (
torch.bfloat16
),并将给定挂钩的结果张量转换回输入数据类型,例如float32
。因此,
bf16_compress_hook
等效于bf16_compress_wrapper(allreduce_hook)
。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
- 返回类型
Callable[[Any, GradBucket], Future[Tensor]]
PowerSGD 通信挂钩#
PowerSGD (Vogels et al., NeurIPS 2019) 是一种梯度压缩算法,可以提供非常高的压缩率并加速受带宽限制的分布式训练。该算法需要维护一些超参数和内部状态。因此,PowerSGD 通信挂钩是一个 **有状态** 挂钩,用户需要提供如下定义的状态对象。
PowerSGD 状态#
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]#
在训练期间存储算法的超参数和所有梯度的内部状态。
特别是,
matrix_approximation_rank
和start_powerSGD_iter
是用户应该调整的主要超参数。为了性能,我们建议保持二进制超参数use_error_feedback
和warm_start
为开启状态。matrix_approximation_rank
控制压缩低秩张量的大小,这决定了压缩率。秩越低,压缩越强。1.1. 如果
matrix_approximation_rank
过低,模型可能需要更多的训练步骤才能达到完整的模型质量,或者永远无法达到并导致精度损失。1.2.
matrix_approximation_rank
的增加会大幅增加压缩的计算成本,并且精度可能不会超过某个matrix_approximation_rank
阈值而进一步提高。
为了调整
matrix_approximation_rank
,我们建议从 1 开始,并以 2 的因子增加(例如指数网格搜索,1、2、4…),直到达到令人满意的精度。通常只使用较小的值 1-4。对于某些 NLP 任务(如原始论文附录 D 所示),此值已增加到 32。start_powerSGD_iter
将 PowerSGD 压缩推迟到步骤start_powerSGD_iter
之后,而在此之前运行 vanilla allreduce。这种 **vanilla allreduce + PowerSGD** 的混合方案可以有效地提高精度,即使使用相对较小的matrix_approximation_rank
。这是因为训练阶段开始时通常对不准确的梯度非常敏感,过早压缩梯度可能会使训练迅速进入次优轨迹,从而对精度造成不可恢复的影响。
为了调整
start_powerSGD_iter
,我们建议从总训练步骤的 10% 开始,并逐渐增加,直到达到令人满意的精度。如果训练中存在预热阶段,则start_powerSGD_iter
通常不应小于预热步骤数。min_compression_rate
是层被压缩时所需的最小压缩率。由于压缩引起的计算开销,只有当带宽有足够的节省时才值得压缩张量,其中(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols
。如果无法满足指定的压缩率阈值,则张量将直接进行 allreduce 而不进行压缩。
一旦 PowerSGD 压缩开始,压缩统计信息将每
compression_stats_logging_frequency
次迭代记录一次。orthogonalization_epsilon
可以是一个非常小的值(例如 1e-8),在正交化步骤中添加到每个归一化矩阵列中,以防止在任何列全部为 0 时出现除零错误。如果这已经可以防止(例如,通过批量归一化),建议为了精度将 epsilon 设置为 0。batch_tensors_with_same_shape
控制是否通过批量操作压缩和解压缩具有相同形状的张量以实现更高的并行度。请注意,您还应该增加 bucket 大小(即 DDP 构造函数中的bucket_cap_mb
参数),以便在同一个 bucket 中出现更多相同形状的张量,但这可能会减少计算和通信之间的重叠,并由于堆叠相同形状的张量而增加内存占用。如果压缩/解压缩计算是瓶颈,则设置为True
。
警告
如果启用错误反馈或热启动,DDP 中允许的
start_powerSGD_iter
的最小值为 2。这是因为 DDP 在迭代 1 时有另一个内部优化会重建 bucket,这可能会与重建过程之前记忆的任何张量发生冲突。
PowerSGD 挂钩#
警告
PowerSGD 通常需要与模型梯度相同大小的额外内存,以启用错误反馈,这可以补偿有偏的压缩通信并提高精度。
警告
PowerSGD 挂钩可能与 Apex 自动混合精度包 冲突。请改用 PyTorch 原生自动混合精度包。
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]#
实现 PowerSGD 算法。
此 DDP 通信挂钩实现了 论文 中描述的 PowerSGD 梯度压缩算法。一旦梯度张量在所有 worker 之间聚合,此挂钩将按如下方式应用压缩
将输入的扁平化 1D 梯度张量视为每个参数张量的列表,并将所有张量分为两组
1.1. 在 allreduce 之前应该压缩的张量,因为压缩可以节省足够的带宽。
1.2. 其余张量将直接进行 allreduce 而不进行压缩,包括所有向量张量(用于偏差)。
处理未压缩张量
2.1. 为这些未压缩张量分配连续内存,并批量 allreduce 所有未压缩张量,不进行压缩;
2.2. 将单个未压缩张量从连续内存复制回输入张量。
处理应通过 PowerSGD 压缩的张量
3.1. 对于每个张量 M,创建两个低秩张量 P 和 Q 以分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;
3.2. 计算 Ps 中的每个 P,它等于 MQ;
3.3. 批量 allreduce Ps;
3.4. 正交化 Ps 中的每个 P;
3.5. 计算 Qs 中的每个 Q,它近似等于 M^TP;
3.6. 批量 allreduce Qs;
3.7. 计算所有压缩张量中的每个 M,它近似等于 PQ^T。
请注意,此通信挂钩在第一个
state.start_powerSGD_iter
迭代中强制执行 vanilla allreduce。这不仅使用户能够更好地控制加速和精度之间的权衡,而且还有助于抽象出 DDP 内部优化的一些复杂性,以供未来的通信挂钩开发人员使用。- 参数
state (PowerSGDState) – 用于配置压缩率和支持错误反馈、热启动等的状态信息。要调整压缩配置,主要需要调整
matrix_approximation_rank
、start_powerSGD_iter
和min_compression_rate
。bucket (dist.GradBucket) – 存储一个 1D 扁平化梯度张量的桶,该张量批量处理多个每变量张量。请注意,由于 DDP comm 挂钩仅支持单进程单设备模式,因此此桶中只存储一个张量。
- 返回
通信的未来处理程序,它原地更新梯度。
- 返回类型
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5) >>> ddp_model.register_comm_hook(state, powerSGD_hook)
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]#
实现简化的 PowerSGD 算法。
此 DDP 通信挂钩实现了 论文 中描述的简化 PowerSGD 梯度压缩算法。此变体不逐层压缩梯度,而是压缩批量处理所有梯度的扁平化输入张量。因此,它比
powerSGD_hook()
**更快**,但通常会导致 **精度低得多**,除非matrix_approximation_rank
为 1。警告
在此处增加
matrix_approximation_rank
不一定会提高精度,因为不对齐列/行批量处理每个参数张量可能会破坏低秩结构。因此,用户应始终首先考虑powerSGD_hook()
,仅当matrix_approximation_rank
为 1 时才能达到令人满意的精度时才考虑此变体。一旦梯度张量在所有 worker 之间聚合,此挂钩将按如下方式应用压缩
将输入的扁平化 1D 梯度张量视为一个带 0 填充的方形张量 M;
创建两个低秩张量 P 和 Q 以分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;
计算 P,它等于 MQ;
allreduce P;
正交化 P;
计算 Q,它近似等于 M^TP;
allreduce Q;
计算 M,它近似等于 PQ^T。
将输入张量截断到原始长度。
请注意,此通信挂钩在第一个
state.start_powerSGD_iter
迭代中强制执行 vanilla allreduce。这不仅使用户能够更好地控制加速和精度之间的权衡,而且还有助于抽象出 DDP 内部优化的一些复杂性,以供未来的通信挂钩开发人员使用。- 参数
state (PowerSGDState) – 用于配置压缩率和支持错误反馈、热启动等的状态信息。要调整压缩配置,主要需要调整
matrix_approximation_rank
和start_powerSGD_iter
。bucket (dist.GradBucket) – 存储一个 1D 扁平化梯度张量的桶,该张量批量处理多个每变量张量。请注意,由于 DDP comm 挂钩仅支持单进程单设备模式,因此此桶中只存储一个张量。
- 返回
通信的未来处理程序,它原地更新梯度。
- 返回类型
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
调试通信挂钩#
顾名思义,调试通信挂钩 **仅** 用于调试和性能优化目的。
警告
调试通信挂钩不一定输出正确的结果。
- torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]#
返回一个包装输入的 Future,因此它是一个不产生任何通信开销的空操作。
此挂钩 **仅** 应用于 allreduce 优化的余量分析,而不是正常的梯度同步。例如,如果注册此挂钩后训练时间仅观察到不到 10% 的加速,则通常意味着 allreduce 在此情况下不是性能瓶颈。如果 GPU 跟踪不易检索或跟踪分析因 allreduce 和计算之间的重叠或跨 rank 的不同步等因素而复杂化,则此类工具特别有用。
- 示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
通信挂钩的检查点#
有状态通信挂钩可以作为模型检查点的一部分保存,以实现训练器重启。为了使挂钩可序列化,应定义 __setstate__
和 __getstate__
。
警告
__getstate__
应从返回的字典中排除不可序列化的属性。
警告
__setstate__
应正确初始化非序列化属性,这些属性已从提供的 state
中排除。
PowerSGDState
已实现 __setstate__
和 __getstate__
,可作为参考。
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
这是一个保存和重新加载 PowerSGD 状态和挂钩的简单端到端示例。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(24,24)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(24,12)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def run_demo(demo_fn, world_size):
mp.spawn(
demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
def demo_serialization(rank, world_size):
setup(rank, world_size)
CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"
model = SimpleModel().to(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank])
powersgd_hook = powerSGD.powerSGD_hook
powersgd_state = powerSGD.PowerSGDState(process_group=None)
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
state = {
'state_dict': ddp_model.state_dict(),
'comm_hook': powersgd_hook,
'comm_hook_state': powersgd_state}
if rank == 0:
torch.save(state, CHECKPOINT)
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(CHECKPOINT, map_location=map_location)
new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
new_ddp_model.load_state_dict(checkpoint['state_dict'])
powersgd_hook = checkpoint['comm_hook']
powersgd_state = checkpoint['comm_hook_state']
new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
if rank == 0:
os.remove(CHECKPOINT)
cleanup()
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_serialization, world_size)