评价此页

torch.distributed.fsdp.fully_shard#

创建日期:2024 年 12 月 04 日 | 最后更新:2025 年 10 月 13 日

PyTorch FSDP2 (fully_shard)#

PyTorch FSDP2 (RFC) 提供了一种全分片数据并行 (FSDP) 实现,旨在实现高性能即时模式 (eager-mode),同时使用按参数分片来提高易用性。

  • 有关详细信息,请参阅 FSDP2 入门教程。

  • 如果您目前正在使用 FSDP1,请考虑使用我们的迁移指南迁移到 FSDP2。

fully_shard(model) 的用户协议如下:

  • 对于模型初始化,fully_shard 会原地将 model.parameters() 从普通的 torch.Tensor 转换为 DTensor。根据设备网格 (device mesh),参数会被移动到相应的设备上。

  • 在正向和反向传播之前,前向/反向预处理钩子 (pre-forward/backward hooks) 负责执行 all-gather 以获取参数,并将 model.parameters() 从 DTensor 转换回普通的 torch.Tensor。

  • 在正向和反向传播之后,后向钩子 (post-forward/backward hooks) 会释放非分片参数(无需通信),并将 model.parameters() 从普通 torch.Tensor 转回 DTensor。

  • 对于优化器,必须使用 DTensor 模型的 model.parameters() 进行初始化,且优化器步骤应在 DTensor 参数上执行。

  • 调用 model(input) 而不是 model.forward(input) 来触发前向预处理钩子以 all-gather 参数。为了使 model.forward(input) 正常工作,用户必须显式调用 model.unshard(),或者使用 register_fsdp_forward_method(model, "forward") 为该前向方法注册钩子。

  • fully_shard 将参数分组以进行单次 all-gather 操作。用户应以自下而上的方式应用 fully_shard。例如,在 Transformer 模型中,应在对根模型应用 fully_shard 之前,先对每一层应用 fully_shard。当应用于根模型时,fully_shard 会排除每一层的 model.parameters(),并将剩余参数(如 embeddings、输出投影层)分组到一个 all-gather 组中。

  • type(model) 会原地与 FSDPModule 进行“联合”。例如,如果模型最初是 nn.Linear 类型,fully_shard 会将 type(model) 从 nn.Linear 原地修改为 FSDPLinearFSDPLinear 同时是 nn.Linear 和 FSDPModule 的实例。它保留了 nn.Linear 的所有方法,同时在 FSDPModule 下公开了 FSDP2 特定的 API,例如 reshard()unshard()

  • 参数的全限定名 (FQN) 保持不变。如果我们调用 model.state_dict(),应用 fully_shard 前后的 FQN 是一样的。这是因为 fully_shard 不会包装模块,只会将钩子注册到原始模块上。

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比

  • FSDP2 使用基于 DTensor 的维度 0 按参数分片,与 FSDP1 的平坦参数分片相比,它提供了更简单的分片表示,同时保持了相似的吞吐性能。具体来说,FSDP2 在数据并行工作进程上对每个参数在维度 0 上进行切片(使用 torch.chunk(dim=0)),而 FSDP1 会将一组张量扁平化、拼接并整体切片,这使得推理每个工作进程上存在哪些数据以及重分片到不同的并行方案变得复杂。按参数分片提供了更直观的用户体验,放宽了对冻结参数的约束,并允许使用无需通信的(分片)状态字典,而在 FSDP1 中,这些通常需要 all-gather。

  • FSDP2 实现了不同的内存管理方法来处理多流使用,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,且不像 FSDP1 的 limit_all_gathers=True 那样需要阻塞 CPU。

  • FSDP2 公开了用于手动控制预取 (prefetching) 和集合通信调度的 API,允许高级用户进行更多自定义。有关详细信息,请参阅下文 FSDPModule 的方法。

  • FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整状态字典 (full state dicts)。相反,用户可以使用 DTensor API(如 DTensor.full_tensor())或使用更高级别的 API(如 PyTorch 分布式检查点的分布式状态字典 API)自行将包含 DTensor 的分片状态字典重分片为完整状态字典。此外,一些其他参数已被删除;详情请见此处

前端 API 是 fully_shard,可以在 module 上调用

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=None, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[源码]#

将全分片数据并行 (FSDP) 应用于 module,其中 FSDP 将模块参数、梯度和优化器状态分片到数据并行工作进程上,以节省内存,代价是增加通信开销。

在初始化时,FSDP 根据 mesh 将模块参数分片到数据并行工作进程上。在前向传播之前,FSDP 对分片参数执行 all-gather,以获取前向计算所需的非分片参数。如果 reshard_after_forwardTrue,则 FSDP 在前向传播后释放非分片参数,并在反向传播计算梯度之前重新进行 all-gather。梯度计算后,FSDP 释放非分片参数,并在数据并行工作进程上对非分片梯度执行 reduce-scatter。

此实现将分片参数表示为在维度 0 上分片的 DTensor,而非分片参数则与 module 上的原始参数保持一致(例如,如果原始是 torch.Tensor,则此处为 torch.Tensor)。module 上的 前向预处理钩子 会 all-gather 参数,而 module 上的 前向钩子 会释放它们(如果需要)。类似的反向钩子会 all-gather 参数,并在稍后释放参数并对梯度执行 reduce-scatter。

由于将多个张量组合在一起进行一次集合通信对通信效率至关重要,此实现将这种分组作为一等公民对待。在 module 上调用 fully_shard() 会构造一个组,其中包括 module.parameters() 中的参数,但不包括已经在较早的子模块调用中分配给其他组的参数。这意味着 fully_shard() 应在您的模型上自下而上地调用。每个组的参数在一次集合通信中 all-gather,其梯度在一次集合通信中 reduce-scatter。将模型划分为多个组(“逐层”)可以实现最佳的内存节省以及通信与计算的重叠。用户通常应该仅在最顶层的根模块上调用 fully_shard()

参数:
  • module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片并分组进行通信的模块或模块列表。

  • mesh (Optional[DeviceMesh]) – 此数据并行网格定义了分片方式和设备。如果是一维的,则参数完全分片在 1D 网格 (FSDP) 上,放置方式为 (Shard(0),)。如果是二维的,则参数在第 1 维上分片,并在第 0 维上复制 (HSDP),放置方式为 (Replicate(), Shard(0))。网格的设备类型决定了用于通信的设备类型;如果是 CUDA 或类 CUDA 设备,则我们使用当前设备。

  • reshard_after_forward (Optional[Union[bool, int]]) –

    这控制了前向传播后的参数行为,可以在内存和通信之间进行权衡。

    • 如果为 True,则在前向传播后重新分片参数,并在反向传播中重新 all-gather。

    • 如果为 False,则在前向传播后将非分片参数保存在内存中,并避免反向传播中的 all-gather。为了获得最佳性能,我们通常将根模块设置为 False,因为根模块通常在反向传播开始时就需要。

    • 如果为 None,则非根模块默认为 True,根模块默认为 False

    • 如果为 int,则表示前向传播后要重分片到的世界大小 (world size)。它应该是 mesh 分片维度大小的一个非平凡除数(即排除 1 和维度大小本身)。一种选择可以是节点内大小(例如 torch.cuda.device_count())。这允许反向传播中的 all-gather 在较小的世界大小上进行,代价是比设置为 True 时更高的内存使用量。

    • 前向传播后,注册到模块的参数取决于此项:如果为 True,则注册的参数为分片参数;如果为 False,则为非分片参数;否则为重分片到较小网格的参数。若要在前向和反向传播之间修改参数,注册的参数必须是分片参数。对于 Falseint 情况,可以通过 reshard() 手动重新分片来实现。

  • shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用对象可用于覆盖参数的分片放置方式,以在除维度 0 之外的维度上分片参数。如果此可调用对象返回 Shard 放置方式(非 None),则 FSDP 将根据该放置方式(例如 Shard(1))进行分片。如果在非零维度上进行分片,我们目前要求均匀分片,即该维度上的张量大小必须能被 FSDP 分片网格大小整除。

  • mp_policy (MixedPrecisionPolicy) – 控制混合精度策略,为该模块提供参数/规约混合精度。详情请参阅 MixedPrecisionPolicy

  • offload_policy (OffloadPolicy) – 控制卸载策略,提供参数/梯度/优化器状态卸载。详情请参阅 OffloadPolicy 及其子类。

  • ignored_params (set[nn.Parameter] | None) – 可选 (Set[nn.Parameter]):将被 FSDP 忽略的参数集合。它们既不会被分片,也不会在初始化期间移动到设备,反向传播时其梯度也不会被规约。

返回:

应用了 FSDP 的模块(原地修改)。

返回类型:

FSDPModule

class torch.distributed.fsdp.FSDPModule(*args, **kwargs)#
reshard()[源码]#

重新分片模块的参数,释放已分配的非分片参数,并将分片参数注册到模块。此方法不是递归的。

set_all_reduce_hook(hook, *, stream=None)[源码]#
参数:
  • hook (Callable[[torch.Tensor], None]) – 用户定义的 all-reduce 钩子,预期签名为 hook(reduce_output: torch.Tensor) -> None,其中 reduce_output 是仅使用 FSDP 时的 reduce-scatter 输出,或者是使用原生 HSDP 时的 all-reduce 输出。

  • stream (Optional[torch.cuda.Stream]) – 运行 all-reduce 钩子的流。仅在不使用原生 HSDP 时才应设置此项。如果使用原生 HSDP,钩子将在原生 HSDP all-reduce 所使用的内部定义的 all-reduce 流中运行。

set_allocate_memory_from_process_group_for_comm(enable)[源码]#

设置是否应使用 ProcessGroup 本身提供的自定义优化分配器(如果有)来分配用于通过集合通信发送和接收数据的临时分段缓冲区 (staging buffers)。这可能使 ProcessGroup 更有效率。例如,使用 NCCL 时,这使其能够利用 SHARP(针对 NVLink 和/或 InfiniBand)进行零拷贝传输。

此方法不能与 set_custom_all_gather()set_custom_reduce_scatter() 一起使用,因为这些 API 允许对每次通信进行更精细的控制,而此方法无法确定它们的缓冲区分配策略。

参数:

enable (bool) – 是否开启 ProcessGroup 分配。

set_custom_all_gather(comm)[源码]#

覆盖默认的 all_gather 通信行为,以更好地控制通信和内存使用。详情请参阅 CommReduceScatter

参数:

comm (AllGather) – 自定义 all-gather 通信。

set_custom_reduce_scatter(comm)[源码]#

覆盖默认的 reduce_scatter 通信行为,以更好地控制通信和内存使用。详情请参阅 CommReduceScatter

参数:

comm (ReduceScatter) – 自定义 reduce_scatter 通信。

set_force_sum_reduction_for_comms(enable)[源码]#

设置是否要求底层集合通信原语显式仅使用“求和 (sum)”类型的规约,即使这需要额外的缩放操作。例如,这是必要的,因为 NCCL 目前仅为此类集合通信支持零拷贝传输。

注意:对于 MTIA 设备,此功能始终是隐式启用的。

注意:如果 FSDP 设置下使用了 set_all_reduce_hook,调用者需要确保 FSDP 单元之间的自定义 all-reduce 也遵循此策略,因为 FSDP 无法再自动处理它。

参数:

enable (bool) – 是否仅使用 ReduceOp.SUM 进行通信。

set_gradient_divide_factor(factor)[源码]#

设置用于梯度规约的自定义除数因子。这可以使用 NCCL 的 PreMulSum 使用自定义规约操作,允许在规约前乘以该因子。

参数:

factor (float) – 自定义除数因子。

set_is_last_backward(is_last_backward)[源码]#

设置下一次反向传播是否为最后一次。在最后一次反向传播时,FSDP 等待挂起的梯度规约,并清除用于反向预取的内部数据结构。这对微批次 (microbatching) 很有用。

set_modules_to_backward_prefetch(modules)[源码]#

设置此 FSDP 模块应在反向传播中显式预取 all-gathers 的 FSDP 模块列表。这会覆盖默认的反向预取实现(默认基于前向传播顺序的逆序预取下一个 FSDP 模块)。

传递包含上一个 FSDP 模块的单元素列表可获得与默认重叠行为相同的 all-gather 重叠。要获得更积极的重叠,需要传递至少长度为 2 的列表,这将使用更多保留内存。

参数:

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_modules_to_forward_prefetch(modules)[源码]#

设置此 FSDP 模块应在前向传播中显式预取 all-gathers 的 FSDP 模块列表。预取在该模块的 all-gather 拷贝操作完成后运行。

传递包含下一个 FSDP 模块的单元素列表可获得与默认重叠行为相同的 all-gather 重叠,只是预取的 all-gather 会从 CPU 更早发出。要获得更积极的重叠,需要传递至少长度为 2 的列表,这将使用更多保留内存。

参数:

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_post_optim_event(event)[源码]#

设置根 FSDP 模块在 all-gather 流中等待的后优化步骤事件。

默认情况下,根 FSDP 模块会在当前流上等待 all-gather 流,以确保在 all-gather 之前优化步骤已完成。但是,如果优化步骤后有不相关的计算,这可能会引入错误的依赖关系。此 API 允许用户提供他们自己的事件来等待。在根模块等待该事件后,该事件将被丢弃,因此此 API 应在每次迭代中调用并提供一个新的事件。

参数:

event (torch.Event) – 优化步骤后记录的事件,用于 all-gather 流等待。

set_reduce_scatter_divide_factor(factor)[源码]#

请改用 set_gradient_divide_factor()

set_requires_all_reduce(requires_all_reduce, *, recurse=True)[源码]#

设置模块是否应 all-reduce 梯度。这可用于实现梯度累积,对于 HSDP,仅执行 reduce-scatter 而不执行 all-reduce。

set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[源码]#

设置模块是否应同步梯度。这可用于实现无需通信的梯度累积。对于 HSDP,这同时控制 reduce-scatter 和 all-reduce。这相当于 FSDP1 中的 no_sync

参数:
  • requires_gradient_sync (bool) – 是否为模块的参数规约梯度。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,还是仅为传入的模块设置。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)[源码]#

设置模块是否应在反向传播后重分片参数。这可以在梯度累积期间使用,以用更多的内存换取减少的通信,因为非分片参数在下一次前向传播之前不需要重新 all-gather。

参数:
  • reshard_after_backward (bool) – 是否在反向传播后重分片参数。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,还是仅为传入的模块设置。

set_reshard_after_forward(reshard_after_forward, recurse=True)[源码]#

设置模块是否应在前向传播后重分片参数。这可用于在运行时更改 reshard_after_forward FSDP 参数。例如,这可用于将 FSDP 根模块的值设置为 True(因为默认情况下它被特殊设置为 False),或者用于将 FSDP 模块的值设置为 False 以进行评估,并再设置为 True 以进行训练。

参数:
  • reshard_after_forward (bool) – 是否在前向传播后重分片参数。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,还是仅为传入的模块设置。

set_unshard_in_backward(unshard_in_backward)[源码]#

设置 FSDP 模块的参数是否需要在反向传播中解分片 (unshard)。这可用于专家场景,当用户知道该 FSDP 模块参数组中的所有参数在反向传播计算中都不需要时(例如嵌入层)。

unshard(async_op=False)[源码]#

通过分配内存并执行全收集(all-gather)操作来解分片(unshard)模块参数。此方法递归。解分片遵循 MixedPrecisionPolicy,因此如果设置了 param_dtype,它将按照该数据类型执行全收集。

参数:

async_op (bool) – 如果为 True,则返回一个 UnshardHandle,该对象具有 wait() 方法来等待解分片操作完成。如果为 False,则返回 None 并在函数内部等待句柄完成。

返回类型:

UnshardHandle | None

注意

如果 async_op=True,FSDP 将在模块的前向传播前(pre-forward)自动等待挂起的解分片操作。只有当用户需要在前向传播前显式等待时,才需要调用 wait()

class torch.distributed.fsdp.UnshardHandle#

用于等待 FSDPModule.unshard() 操作的句柄。

wait()[source]#

等待解分片操作完成。这确保当前流可以使用已解分片的参数,这些参数现已注册到模块中。

torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]#

module 上的某个方法注册为 FSDP 的前向传播方法。

FSDP 会在前向传播前全收集参数,并根据 reshard_after_forward 的设置,选择性地在前向传播后释放参数。默认情况下,FSDP 仅知道对 nn.Module.forward() 执行此操作。此函数通过修补用户指定的方法,使其在方法执行前后分别运行前向传播的前/后钩子(hooks)。如果 module 不是 FSDPModule,则此函数为空操作。

参数:
  • module (nn.Module) – 要注册前向传播方法的模块。

  • method_name (str) – 前向传播方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)#

该类用于配置 FSDP 的混合精度。与 autocast 不同,此策略在模块级别而非算子级别应用混合精度,这意味着反向传播时会保存低精度激活值,且高到低精度的类型转换仅发生在模块边界处。

FSDP 与模块级混合精度配合良好,因为它始终将高精度分片参数保留在内存中。换句话说,FSDP 不需要额外的内存来为优化器步骤保留一份高精度参数副本。

变量:
  • param_dtype (Optional[torch.dtype]) – 指定用于解分片参数的数据类型,从而确定前向/反向传播计算及参数全收集的数据类型。如果为 None,则解分片参数使用原始数据类型。优化器步骤使用原始数据类型的分片参数。(默认值:None

  • reduce_dtype (Optional[torch.dtype]) – 指定用于梯度规约(即 reduce-scatter 或 all-reduce)的数据类型。如果为 Noneparam_dtype 不为 None,则规约使用计算数据类型。这可用于在低精度计算的同时以全精度运行梯度规约。如果通过 set_requires_gradient_sync() 禁用了梯度规约,FSDP 将使用 reduce_dtype 累积梯度。(默认值:None

  • output_dtype (Optional[torch.dtype]) – 指定用于转换浮点前向输出的数据类型。这可用于实现不同模块具有不同混合精度策略的情况。(默认值:None

  • cast_forward_inputs (bool) – 指定 FSDP 是否应将前向传播的浮点输入张量转换为 param_dtype

class torch.distributed.fsdp.OffloadPolicy#

该基类代表不卸载的策略,仅用作 offload_policy 参数的默认值。

class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)#

此卸载策略将参数、梯度和优化器状态卸载到 CPU。分片参数在全收集之前从主机拷贝到设备。全收集后的参数会根据 reshard_after_forward 进行释放。分片梯度在反向传播期间从设备拷贝到主机,优化器步骤在 CPU 上使用 CPU 优化器状态运行。

变量:

pin_memory (bool) – 是否锁定(pin)分片参数和梯度的内存。锁定内存不仅能提高 H2D/D2H 拷贝效率,还能使拷贝过程与计算重叠。但需注意,锁定的内存不能被其他进程使用。如果 CPU 内存不足,请将其设置为 False。(默认值:True

torch.distributed.fsdp.share_comm_ctx(modules)[source]#

为多个 FSDPModule 共享 CUDA 流。

使用示例

from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([fsdp_model_1, fsdp_model_2, …])

对于流水线并行(PP),每个模型分块都是一个 FSDP 根节点。我们需要为 all-gather、reduce-scatter 和 all-reduce 共享 CUDA 流,以避免分配流间内存碎片。

参数:

modules (List[FSDPModule]) – 需要共享 CUDA 流的模块。