评价此页

torch.distributed.fsdp.fully_shard#

创建于:2024年12月04日 | 最后更新于:2025年06月16日

PyTorch FSDP2 (fully_shard)#

PyTorch FSDP2 (RFC) 提供了一个完全分片数据并行 (FSDP) 实现,旨在为高性能即时模式提供支持,同时使用按参数分片以提高可用性。

  • 有关更多信息,请参阅 FSDP2 入门 教程。

  • 如果您当前正在使用 FSDP1,请考虑使用我们的 迁移指南 迁移到 FSDP2。

fully_shard(model) 的用户约定如下:

  • 对于模型初始化,fully_shardmodel.parameters() 从普通 torch.Tensor 就地转换为 DTensor。参数根据设备网格移动到相应的设备。

  • 在前向和反向传播之前,前向/反向钩子负责对参数进行全收集,并将 model.parameters()DTensor 转换为普通 torch.Tensor

  • 在前向和反向传播之后,后向/反向钩子释放未分片的参数(不需要通信),并将 model.parameters() 从普通 torch.Tensor 转换回 DTensor

  • 对于优化器,它必须使用 DTensor model.parameters() 进行初始化,并且优化器步骤应在 DTensor 参数上执行。

  • 调用 model(input) 而不是 model.forward(input) 以触发前向钩子来全收集参数。为了使 model.forward(input) 工作,用户必须显式调用 model.unshard() 或使用 register_fsdp_forward_method(model, "forward") 注册前向方法以进行钩子操作。

  • fully_shard 将参数分组在一起进行一次全收集。用户应以自底向上的方式应用 fully_shard。例如,在 Transformer 模型中,应在将其应用于根模型之前,将 fully_shard 应用于每个层。当应用于根模型时,fully_shard 将从每个层中排除 model.parameters(),并将剩余的参数(例如,嵌入、输出投影)分组到一个全收集组中。

  • type(model)FSDPModule 就地“联合”。例如,如果模型最初是 nn.Linear 类型,则 fully_shard 会将 type(model)nn.Linear 就地更改为 FSDPLinearFSDPLinearnn.LinearFSDPModule 的一个实例。它保留了 nn.Linear 的所有方法,同时还在 FSDPModule 下公开了 FSDP2 特定的 API,例如 reshard()unshard()

  • 参数的完全限定名称 (FQNs) 保持不变。如果调用 model.state_dict(),在应用 fully_shard 前后,FQNs 是相同的。这是因为 fully_shard 不会包装模块,而只是向原始模块注册钩子。

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比

  • FSDP2 使用基于 DTensor 的 dim-0 每参数分片,与 FSDP1 的扁平参数分片相比,提供了更简单的分片表示,同时保持了相似的吞吐量性能。更具体地说,FSDP2 使用 torch.chunk(dim=0) 在 dim-0 上将每个参数分块到数据并行工作器中,而 FSDP1 则将一组张量扁平化、连接并分块在一起,这使得推理每个工作器上存在哪些数据以及重新分片到不同并行度变得复杂。每参数分片提供了更直观的用户体验,放宽了对冻结参数的约束,并允许无通信(分片)状态字典,否则在 FSDP1 中需要全部聚合。

  • FSDP2 实现了不同的内存管理方法来处理多流使用,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,并且不需要像 FSDP1 的 limit_all_gathers=True 那样阻塞 CPU。

  • FSDP2 暴露了用于手动控制预取和集体调度的 API,允许高级用户进行更多自定义。有关详细信息,请参阅下面 FSDPModule 上的方法。

  • FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整状态字典。相反,用户可以使用 DTensor API(如 DTensor.full_tensor())或使用更高级别的 API(如 PyTorch Distributed Checkpoint 的分布式状态字典 API)将包含 DTensor 的分片状态字典重新分片为完整状态字典。此外,还删除了一些其他参数;有关详细信息,请参阅此处

前端 API 是 fully_shard,可以在 module 上调用

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=None, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[source]#

module 应用全分片数据并行 (FSDP),FSDP 将模块参数、梯度和优化器状态分片到数据并行工作器中以节省内存,代价是通信。

在初始化时,FSDP 根据 mesh 将模块的参数分片到数据并行工作器中。在正向传播之前,FSDP 在数据并行工作器之间全部聚合分片参数以获取用于正向计算的未分片参数。如果 reshard_after_forwardTrue,则 FSDP 在正向传播后释放未分片参数,并在梯度计算之前的反向传播中重新全部聚合它们。在梯度计算之后,FSDP 释放未分片参数并将未分片梯度缩减散布到数据并行工作器中。

此实现将分片参数表示为在 dim-0 上分片的 DTensor,而未分片参数将类似于 module 上的原始参数(例如,如果原始为 torch.Tensor,则为 torch.Tensor)。module 上的模块 正向预钩子 会全部聚合参数,而 module 上的模块 正向钩子 会释放它们(如果需要)。类似的反向钩子会全部聚合参数,然后释放参数并缩减散布梯度。

由于将多个张量组合在一起进行一次集合对于通信效率至关重要,因此此实现将此分组作为一流公民。在 module 上调用 fully_shard() 会构建一个组,其中包括 module.parameters() 中的参数,但那些已从子模块的早期调用分配到组的参数除外。这意味着 fully_shard() 应该在模型的底部向上调用。每个组的参数在一次集合中全部聚合,其梯度在一次集合中缩减散布。将模型划分为多个组(“逐层”)可以实现峰值内存节省和通信/计算重叠。用户通常不应仅在最顶层的根模块上调用 fully_shard()

参数
  • module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片并分组进行通信的模块或模块列表。

  • mesh (Optional[DeviceMesh]) – 此数据并行网格定义了分片和设备。如果为 1D,则参数通过 (Shard(0),) 位置在 1D 网格 (FSDP) 中完全分片。如果为 2D,则参数通过 (Replicate(), Shard(0)) 位置在第 1 维上分片并在第 0 维上复制 (HSDP)。网格的设备类型指定了用于通信的设备类型;如果是 CUDA 或类似 CUDA 的设备类型,则使用当前设备。

  • reshard_after_forward (Optional[Union[bool, int]]) –

    这控制了正向传播后的参数行为,可以在内存和通信之间进行权衡

    • 如果为 True,则在正向传播后重新分片参数并在反向传播中重新全部聚合。

    • 如果为 False,则在正向传播后将未分片参数保留在内存中,并避免在反向传播中进行全部聚合。为了获得最佳性能,我们通常将根模块设置为 False,因为反向传播开始时通常立即需要根模块。

    • 如果为 None,则对于非根模块设置为 True,对于根模块设置为 False

    • 如果为 int,则表示正向传播后重新分片到的世界大小。它应该是 mesh 分片维度的非平凡除数(即排除 1 和维度大小本身)。一个选择可能是节点内大小(例如 torch.cuda.device_count())。这允许反向传播中的全部聚合发生在较小的世界大小上,代价是比设置为 True 时更高的内存使用量。

    • 在正向传播之后,注册到模块的参数取决于此:如果为 True,则注册的参数是分片参数;如果为 False,则为未分片参数;否则为重新分片到较小网格的参数。要在正向传播和反向传播之间修改参数,注册的参数必须是分片参数。对于 Falseint,可以通过 reshard() 手动重新分片。

  • shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用函数可用于覆盖参数的分片位置,以在非 dim-0 的维度上分片参数。如果此可调用函数返回 Shard 位置(非 None),则 FSDP 将根据该位置进行分片(例如 Shard(1))。如果在非零维度上分片,我们目前要求均匀分片,即该维度上的张量维度大小必须能被 FSDP 分片网格大小整除。

  • mp_policy (MixedPrecisionPolicy) – 这控制了混合精度策略,它为该模块提供了参数/缩减混合精度。有关详细信息,请参阅 MixedPrecisionPolicy

  • offload_policy (OffloadPolicy) – 这控制了卸载策略,它提供了参数/梯度/优化器状态卸载。有关详细信息,请参阅 OffloadPolicy 及其子类。

  • ignored_params (Optional[set[nn.Parameter]]) – Optional(Set[nn.Parameter]):FSDP 忽略的参数集。它们不会被分片,也不会在初始化期间移动到设备,也不会在反向传播中减少它们的梯度。

返回

应用了 FSDP 的模块(原地操作)。

返回类型

FSDPModule

class torch.distributed.fsdp.FSDPModule(*args, **kwargs)#
reshard()[source]#

重新分片模块的参数,如果已分配未分片参数则释放它们,并将分片参数注册到模块。此方法递归。

set_all_reduce_hook(hook, *, stream=None)[source]#
参数
  • hook (Callable[[torch.Tensor], None]) – 用户定义的全部聚合钩子,预期签名为 hook(reduce_output: torch.Tensor) -> None,其中 reduce_output 如果只使用 FSDP 则是缩减散布输出,如果使用原生 HSDP 则是全部聚合输出。

  • stream (Optional[torch.cuda.Stream]) – 运行全部聚合钩子的流。这只应在不使用原生 HSDP 时设置。如果使用原生 HSDP,钩子将在原生 HSDP 全部聚合使用的内部定义的全部聚合流中运行。

set_allocate_memory_from_process_group_for_comm(enable)[source]#

设置通过集体通信发送和接收数据时使用的临时暂存缓冲区是否应使用 ProcessGroup 本身提供的自定义优化分配器(如果有)进行分配。这可能允许 ProcessGroup 更高效。例如,当使用 NCCL 时,这使其能够利用 SHARP(对于 NVLink 和/或 InfiniBand)上的零拷贝传输。

参数

enable (bool) – 是否启用 ProcessGroup 分配。

set_force_sum_reduction_for_comms(enable)[source]#

设置是否要求底层集体通信原语专门使用“求和”类型缩减,即使这意味着需要单独的额外预缩放或后缩放操作。例如,NCCL 目前只支持这种集合的零拷贝传输。

注意:对于 MTIA 设备,这始终是隐式启用的。

注意:如果在 FSDP 设置下使用 set_all_reduce_hook,调用者需要确保跨 FSDP 单元的自定义全部聚合也遵循此策略,因为 FSDP 不再能自动处理。

参数

enable (bool) – 通信是否仅使用 ReduceOp.SUM。

set_gradient_divide_factor(factor)[source]#

为梯度缩减设置自定义除数。这可能使用 NCCL 的 PreMulSum 自定义缩减操作,它允许在缩减之前乘以因子。

参数

factor (float) – 自定义除数。

set_is_last_backward(is_last_backward)[source]#

设置下一次反向传播是否为最后一次。在最后一次反向传播时,FSDP 等待挂起的梯度缩减并清除内部数据结构以进行反向预取。这对于微批处理很有用。

set_modules_to_backward_prefetch(modules)[source]#

设置此 FSDP 模块应在反向传播中明确预取全部聚合的 FSDP 模块。这会覆盖默认的反向预取实现,该实现根据反向正向传播顺序预取下一个 FSDP 模块。

传递包含前一个 FSDP 模块的单例列表会产生与默认重叠行为相同的全部聚合重叠行为。对于更积极的重叠,需要传递至少两个元素的列表,并且将使用更多保留内存。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_modules_to_forward_prefetch(modules)[source]#

设置此 FSDP 模块应在正向传播中明确预取全部聚合的 FSDP 模块。预取在此模块的全部聚合复制输出后运行。

传递包含下一个 FSDP 模块的单例列表会产生与默认重叠行为相同的全部聚合重叠行为,不同之处在于预取的全部聚合会更早地从 CPU 发出。对于更积极的重叠,需要传递至少两个元素的列表,并且将使用更多保留内存。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_post_optim_event(event)[source]#

为根 FSDP 模块设置一个优化器步骤后事件,以等待全部聚合流。

默认情况下,根 FSDP 模块在当前流上等待全部聚合流,以确保优化器步骤在全部聚合之前完成。然而,如果优化器步骤之后有不相关的计算,这可能会引入错误依赖。此 API 允许用户提供自己的事件来等待。根模块等待事件后,事件将被丢弃,因此此 API 应在每次迭代时使用新事件调用。

参数

event (torch.Event) – 优化器步骤后记录的事件,用于等待全部聚合流。

set_reduce_scatter_divide_factor(factor)[source]#

请改用 set_gradient_divide_factor()

set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source]#

设置模块是否应全部聚合梯度。这可用于实现仅使用缩减散布而不全部聚合 HSDP 的梯度累积。

set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source]#

设置模块是否应同步梯度。这可用于实现不进行通信的梯度累积。对于 HSDP,这同时控制缩减散布和全部聚合。这相当于 FSDP1 中的 no_sync

参数
  • requires_gradient_sync (bool) – 是否减少模块参数的梯度。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,或者仅为传入的模块设置。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source]#

设置模块是否应在反向传播后重新分片参数。这可以在梯度累积期间使用,以牺牲更高的内存来减少通信,因为未分片参数不需要在下一次正向传播之前重新全部聚合。

参数
  • reshard_after_backward (bool) – 是否在反向传播后重新分片参数。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,或者仅为传入的模块设置。

set_reshard_after_forward(reshard_after_forward, recurse=True)[source]#

设置模块是否应在正向传播后重新分片参数。这可用于在运行时更改 reshard_after_forward FSDP 参数。例如,这可用于将 FSDP 根模块的值设置为 True(因为它通常被特别设置为 False),或者可以将 FSDP 模块的值设置为 False 以运行评估,然后设置回 True 以进行训练。

参数
  • reshard_after_forward (bool) – 是否在正向传播后重新分片参数。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,或者仅为传入的模块设置。

set_unshard_in_backward(unshard_in_backward)[source]#

设置 FSDP 模块的参数是否需要在反向传播中取消分片。这可以在用户知道此 FSDP 模块的参数组中的所有参数都不需要进行反向计算(例如嵌入)的专家情况下使用。

unshard(async_op=False)[source]#

通过分配内存并全部聚合参数来取消分片模块的参数。此方法递归。取消分片遵循 MixedPrecisionPolicy,因此如果设置,它将按照 param_dtype 全部聚合。

参数

async_op (bool) – 如果为 True,则返回一个 UnshardHandle,它有一个 wait() 方法来等待取消分片操作。如果为 False,则返回 None 并在此函数内部等待句柄。

返回类型

Optional[UnshardHandle]

注意

如果 async_op=True,则 FSDP 将在模块的正向预处理中等待挂起的取消分片。用户只有在等待需要在正向预处理之前发生时才需要明确调用 wait()

class torch.distributed.fsdp.UnshardHandle#

一个等待 FSDPModule.unshard() 操作的句柄。

wait()[source]#

等待取消分片操作。这确保了当前流可以使用未分片参数,这些参数现在已注册到模块。

torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]#

注册 module 上的方法,使其被视为 FSDP 的正向方法。

FSDP 在正向传播前全部聚合参数,并可选择在正向传播后释放参数(取决于 reshard_after_forward)。默认情况下,FSDP 只知道对 nn.Module.forward() 执行此操作。此函数修补用户指定的方法,使其在方法之前/之后分别运行正向预/后钩子。如果 module 不是 FSDPModule,则此操作为空操作。

参数
  • module (nn.Module) – 要注册正向方法的模块。

  • method_name (str) – 正向方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)#

这配置了 FSDP 的混合精度。与自动转换不同,这在模块级别而不是操作级别应用混合精度,这意味着低精度激活用于反向传播,高到低精度的转换仅发生在模块边界。

FSDP 与模块级混合精度配合良好,因为它无论如何都将高精度分片参数保留在内存中。换句话说,FSDP 不需要任何额外内存来为优化器步骤保留参数的高精度副本。

变量
  • param_dtype (Optional[torch.dtype]) – 这指定了未分片参数的数据类型,因此也指定了正向/反向计算和参数全部聚合的数据类型。如果这是 None,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型中的分片参数。(默认值: None)

  • reduce_dtype (Optional[torch.dtype]) – 这指定了梯度缩减(即缩减散布或全部聚合)的数据类型。如果此值为 Noneparam_dtype 不为 None,则缩减使用计算数据类型。这可用于以全精度运行梯度缩减,同时使用低精度进行计算。如果还通过 set_requires_gradient_sync() 禁用了梯度缩减,则 FSDP 将使用 reduce_dtype 累积梯度。(默认值: None)

  • output_dtype (Optional[torch.dtype]) – 这指定了浮点正向输出的转换数据类型。这可用于帮助实现不同模块具有不同混合精度策略的情况。(默认值: None)

  • cast_forward_inputs (bool) – 这指定 FSDP 是否应将正向传播的浮点输入张量转换为 param_dtype

class torch.distributed.fsdp.OffloadPolicy#

此基类表示无卸载策略,仅用作 offload_policy 参数的默认值。

class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)#

此卸载策略将参数、梯度和优化器状态卸载到 CPU。分片参数在全部聚合之前从主机复制到设备。全部聚合的参数根据 reshard_after_forward 释放。分片梯度在反向传播中从设备复制到主机,优化器步骤在 CPU 上运行,使用 CPU 优化器状态。

变量

pin_memory (bool) – 是否固定分片参数和梯度内存。固定内存可以更高效地进行 H2D/D2H 复制,并允许复制与计算重叠。但是,固定内存不能被其他进程使用。如果 CPU 内存不足,请将其设置为 False。(默认值: True)