torch.distributed.fsdp.fully_shard#
创建时间:2024 年 12 月 04 日 | 最后更新时间:2025 年 06 月 16 日
PyTorch FSDP2(fully_shard
)#
PyTorch FSDP2(RFC)提供了一种完全分片数据并行(FSDP)实现,旨在提高易用性的每参数分片,同时实现高性能的动态模式。
fully_shard(model)
的用户协议如下:
在模型初始化时,`fully_shard` 会将 `model.parameters()` 从普通的 `torch.Tensor` 就地转换为 `DTensor`。参数会根据设备网格(device mesh)移动到相应的设备上。
在前向和后向传播之前,前向/后向预钩子(pre-forward/backward hooks)负责 all-gather 参数并将 `model.parameters()` 从 `DTensor` 转换回普通的 `torch.Tensor`。
在前向和后向传播之后,后向/后向钩子(post-forward/backward hooks)会释放未分片(unsharded)的参数(无需通信),并将 `model.parameters()` 从普通的 `torch.Tensor` 转换回 `DTensor`。
对于优化器,它必须使用 `DTensor` `model.parameters()` 进行初始化,并且优化器步骤应该在 `DTensor` 参数上执行。
调用
model(input)
而不是model.forward(input)
来触发预前向钩子(pre-forward hooks)进行参数的 all-gather。要使 `model.forward(input)` 工作,用户必须显式调用model.unshard()
,或者使用register_fsdp_forward_method(model, "forward")
来注册 forward 方法以便挂载钩子。`fully_shard` 将参数分组以进行单次 all-gather。用户应自底向上应用 `fully_shard`。例如,在 Transformer 模型中,`fully_shard` 应在应用于根模型之前应用于每个层。当应用于根模型时,`fully_shard` 会排除每个层中的 `model.parameters()`,并将剩余的参数(例如,嵌入层、输出投影层)分组到单个 all-gather 组中。
`type(model)` 会就地与 `FSDPModule` “联合”(unioned)。例如,如果模型最初是 `nn.Linear` 类型,那么 `fully_shard` 会就地将 `type(model)` 从 `nn.Linear` 更改为 `FSDPLinear`。`FSDPLinear` 是 `nn.Linear` 和 `FSDPModule` 的实例。它保留了 `nn.Linear` 的所有方法,同时还在 `FSDPModule` 下公开了 FSDP2 特定的 API,例如
reshard()
和unshard()
。参数的完全限定名(FQNs)保持不变。如果我们调用
model.state_dict()
,则在应用 `fully_shard` 前后 FQNs 相同。这是因为 `fully_shard` 不会包装模块,而只是向原始模块注册钩子。
与 PyTorch FSDP1(FullyShardedDataParallel
)相比
FSDP2 使用基于 `DTensor` 的 dim-0 每参数分片,相比 FSDP1 的平坦参数分片,分片表示更简单,同时保持了相似的吞吐量性能。更具体地说,FSDP2 在数据并行工作器上沿 dim-0 对每个参数进行分块(使用
torch.chunk(dim=0)
),而 FSDP1 则会将一组张量展平、连接和分块,使得推断每个工作器上存在哪些数据以及分片到不同并行性变得复杂。每参数分片提供了更直观的用户体验,放宽了对冻结参数的限制,并允许无通信(分片)的状态字典,而这在 FSDP1 中需要 all-gather。FSDP2 实现了一种不同的内存管理方法来处理多流使用,避免了
torch.Tensor.record_stream
。这确保了可预测且符合预期的内存使用,并且不需要像 FSDP1 的limit_all_gathers=True
那样阻塞 CPU。FSDP2 公开了用于手动控制预取和集合调度(collective scheduling)的 API,为高级用户提供了更多自定义选项。有关详细信息,请参阅下面的
FSDPModule
方法。FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整的状态字典。相反,用户可以使用
DTensor
API(如DTensor.full_tensor()
)或使用更高级别的 API(如 PyTorch Distributed Checkpoint 的分布式状态字典 API)将包含DTensor
的分片状态字典重塑为完整的状态字典。此外,还移除了一些其他参数;有关详细信息,请参阅 此处。
前端 API 是 fully_shard
,可以对 module
调用。
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=None, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[source]#
对 `module` 应用完全分片数据并行(FSDP),其中 FSDP 将模块参数、梯度和优化器状态分片到数据并行工作器上,以节省内存,代价是通信开销。
在初始化时,FSDP 会根据 `mesh` 给出的数据并行工作器对模块的参数进行分片。在前向传播之前,FSDP 会将分片后的参数 all-gather 到数据并行工作器上,以获取用于前向计算的未分片参数。如果 `reshard_after_forward` 为 `True`,则 FSDP 会在前向传播后释放未分片参数,并在后向传播中重新 all-gather 它们以进行梯度计算。梯度计算之后,FSDP 会释放未分片参数,并将未分片梯度 reduce-scatter 到数据并行工作器上。
此实现将分片参数表示为在 dim-0 上分片的 `DTensor`,而未分片参数将类似于 `module` 上的原始参数(例如,如果最初是 `
torch.Tensor
`,则为 `torch.Tensor
`)。`module` 上的一个前向预钩子(forward pre-hook)将参数 all-gather,而一个前向钩子(forward hook)将参数(如果需要)释放。类似的后向钩子将参数 all-gather,然后释放参数并 reduce-scatter 梯度。由于将多个张量组合成一次集体通信(collective)对于通信效率至关重要,因此此实现将这种组合作为首要考虑。对 `module` 调用 `
fully_shard()
会构造一个组,该组包含 `module.parameters()` 中的参数,但排除从早期子模块调用中已分配到组的参数。这意味着 `fully_shard()
应在模型上自底向上调用。每个组的参数在一个集体通信中 all-gather,其梯度在一个集体通信中 reduce-scatter。将模型划分为多个组(“逐层”)可以实现最大的内存节省和计算/通信重叠。用户通常 *不* 应该只在最顶层的根模块上调用 `fully_shard()
`。- 参数
module (Union[nn.Module, List[nn.Module]) – 要用 FSDP 分片并分组进行通信的模块或模块列表。
mesh (Optional[DeviceMesh]) – 此数据并行网格定义了分片和设备。如果是一维的,则参数在整个一维网格上完全分片(FSDP),放置策略为
(Shard(0),)
。如果是二维的,则参数在第一维上分片,在第 0 维上复制(HSDP),放置策略为(Replicate(), Shard(0))
。网格的设备类型决定了通信使用的设备类型;如果是 CUDA 或类似 CUDA 的设备类型,则使用当前设备。reshard_after_forward (Optional[Union[bool, int]]) –
这控制了前向传播后参数的行为,可以权衡内存和通信。
如果为
True
,则在后向传播时重新 all-gather 参数。如果为
False
,则在前向传播后将未分片参数保留在内存中,并避免后向传播中的 all-gather。为了获得最佳性能,我们通常为根模块设置False
,因为根模块在后向传播开始时通常是必需的。如果为
None
,则非根模块设置为True
,根模块设置为False
。如果为
int
,则表示后向传播后重塑的目标世界大小。它应该是 `mesh` 分片维度大小的一个非平凡除数(即不包括 1 和维度本身)。一个可能的选择是节点内大小(例如torch.cuda.device_count()
)。这使得后向传播中的 all-gather 可以使用更小的世界大小,但内存使用量比设置为True
时更高。在前向传播之后,注册到模块的参数取决于此:如果为
True
,注册的参数是分片参数;如果为False
,则是未分片参数;否则是重塑到更小网格的参数。要在前向和后向传播之间修改参数,注册的参数必须是分片参数。对于False
或int
,可以通过手动调用reshard()
来实现。
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用对象可用于覆盖参数的分片放置,以便在 dim-0 以外的维度上分片参数。如果此可调用对象返回一个
Shard
放置(不是None
),则 FSDP 将根据该放置进行分片(例如Shard(1)
)。如果沿非零维度分片,我们目前要求均匀分片,即该维度的张量尺寸必须能被 FSDP 分片网格大小整除。mp_policy (MixedPrecisionPolicy) – 这控制混合精度策略,为该模块提供参数/归约混合精度。有关详细信息,请参阅
MixedPrecisionPolicy
。offload_policy (OffloadPolicy) – 这控制卸载策略,提供参数/梯度/优化器状态卸载。有关详细信息,请参阅
OffloadPolicy
及其子类。ignored_params (Optional[set[nn.Parameter]]) – 可选(Set[nn.Parameter]):要被 FSDP 忽略的参数集合。它们不会被分片,也不会在初始化时移动到设备,也不会在后向传播中减少梯度。
- 返回
应用 FSDP 后的模块(就地)。
- 返回类型
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)#
-
- set_all_reduce_hook(hook, *, stream=None)[source]#
- 参数
hook (Callable[[torch.Tensor], None]) – 用户定义的 all-reduce 钩子,预期签名如下:
hook(reduce_output: torch.Tensor) -> None
,其中reduce_output
是仅使用 FSDP 时的 reduce-scatter 输出,或使用原生 HSDP 时的 all-reduce 输出。stream (Optional[torch.cuda.Stream]) – 在其中运行 all-reduce 钩子的流。仅在不使用原生 HSDP 时设置。如果使用原生 HSDP,钩子将在原生 HSDP all-reduce 使用的内部定义的 all-reduce 流中运行。
- set_allocate_memory_from_process_group_for_comm(enable)[source]#
设置用于通过集体通信发送和接收数据的临时暂存缓冲区(staging buffers)是否应使用进程组(ProcessGroup)本身提供的自定义优化分配器(如果可用)。这可能使进程组更有效。例如,在使用 NCCL 时,这使得它可以利用 SHARP 的零拷贝传输(针对 NVLink 和/或 InfiniBand)。
此选项不能与
set_custom_all_gather()
或set_custom_reduce_scatter()
一起使用,因为这些 API 允许对每个通信进行更细粒度的控制,而此方法无法确定它们的暂存缓冲区分配策略。- 参数
enable (bool) – 是否启用进程组分配。
- set_custom_all_gather(comm)[source]#
覆盖默认的
all_gather
通信行为,以更好地控制通信和内存使用。有关详细信息,请参阅 Comm 和 ReduceScatter。- 参数
comm (AllGather) – 自定义 all-gather 通信。
- set_custom_reduce_scatter(comm)[source]#
覆盖默认的
reduce_scatter
通信行为,以更好地控制通信和内存使用。有关详细信息,请参阅 Comm 和 ReduceScatter。- 参数
comm (ReduceScatter) – 自定义 reduce_scatter 通信。
- set_force_sum_reduction_for_comms(enable)[source]#
设置底层集体通信原语是否必须仅使用“sum”类型的归约(reduction),即使这需要额外的预处理或后处理缩放操作。例如,这对于 NCCL 支持对这类集体通信进行零拷贝传输是必需的。
注意:对于 MTIA 设备,此选项始终隐式启用。
注意:如果 FSDP 设置下使用了 `set_all_reduce_hook`,则调用者需要确保 FSDP 单元之间的自定义 all-reduce 也遵循此策略,因为 FSDP 已无法自动处理。
- 参数
enable (bool) – 是否仅为通信使用 `ReduceOp.SUM`。
- set_gradient_divide_factor(factor)[source]#
设置梯度归约的自定义除法因子。这可能使用 NCCL 的 `PreMulSum` 进行自定义归约操作,该操作允许在归约之前乘以该因子。
- 参数
factor (float) – 自定义除法因子。
- set_is_last_backward(is_last_backward)[source]#
设置下一个后向传播是否是最后一个。在最后一个后向传播时,FSDP 会等待待处理的梯度归约,并清除用于后向预取的内部数据结构。这对于微批次(microbatching)非常有用。
- set_modules_to_backward_prefetch(modules)[source]#
设置 FSDP 模块,这些模块应该在后向传播中显式预取 all-gather。这会覆盖默认的后向预取实现,该实现会根据反向后前向顺序(reverse post-forward order)预取下一个 FSDP 模块。
传递包含前一个 FSDP 模块的单例列表会产生与默认重叠行为相同的 all-gather 重叠行为。传递至少长度为二的列表需要更激进的重叠,并且会占用更多预留内存。
- 参数
modules (List[FSDPModule]) – 要预取的 FSDP 模块。
- set_modules_to_forward_prefetch(modules)[source]#
设置 FSDP 模块,这些模块应该在正向传播中显式预取 all-gather。预取在模块的 all-gather 复制(copy-out)之后运行。
传递包含下一个 FSDP 模块的单例列表会产生与默认重叠行为相同的 all-gather 重叠行为,只是预取的 all-gather 会更早地从 CPU 发出。传递至少长度为二的列表需要更激进的重叠,并且会占用更多预留内存。
- 参数
modules (List[FSDPModule]) – 要预取的 FSDP 模块。
- set_post_optim_event(event)[source]#
为根 FSDP 模块设置一个后优化器步骤事件(post-optimizer-step event),以便 all-gather 流在此事件上等待。
默认情况下,根 FSDP 模块在当前流上等待 all-gather 流,以确保优化器步骤在 all-gather 之前完成。然而,这可能会引入假依赖(false dependencies),如果优化器步骤之后有不相关的计算。此 API 允许用户提供自己的事件来等待。在根模块等待事件后,该事件将被丢弃,因此应在每次迭代中都调用此 API 并提供一个新事件。
- 参数
event (torch.Event) – 优化器步骤后记录的事件,all-gather 流将在此事件上等待。
- set_reduce_scatter_divide_factor(factor)[source]#
请使用
set_gradient_divide_factor()
代替。
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source]#
设置模块是否应 all-reduce 梯度。这可用于实现梯度累积,对于 HSDP 仅使用 reduce-scatter 而不使用 all-reduce。
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source]#
设置模块是否应同步梯度。这可用于实现 *无需通信* 的梯度累积。对于 HSDP,这同时控制 reduce-scatter 和 all-reduce。这是 FSDP1 中 no_sync 的等价形式。
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source]#
设置模块在后向传播后是否应重塑参数。这可以在梯度累积期间使用,以通过增加内存消耗来换取通信减少,因为在下一次前向传播之前不需要重新 all-gather 未分片参数。
- set_reshard_after_forward(reshard_after_forward, recurse=True)[source]#
设置模块在前向传播后是否应重塑参数。这可用于在运行时更改 `reshard_after_forward` FSDP 参数。例如,这可用于将 FSDP 根模块的值设置为
True
(因为它否则被特殊设置为False
),或者将其设置为False
以运行评估,并在训练时重新设置为True
。
- set_unshard_in_backward(unshard_in_backward)[source]#
设置 FSDP 模块的参数是否需要在后向传播中进行 unsharded。这可用于专家场景,当用户知道该 FSDP 模块的参数组中的所有参数在后向传播计算中都不需要时(例如,嵌入层)。
- unshard(async_op=False)[source]#
通过分配内存并 all-gather 参数来 unsharded 模块的参数。此方法 *不是* 递归的。unshard 操作遵循 `
MixedPrecisionPolicy
`,因此如果设置了 `param_dtype`,它将按照 `param_dtype` 进行 all-gather。- 参数
async_op (bool) – 如果为
True
,则返回一个具有 `wait()` 方法的 `UnshardHandle
` 以等待 unshard 操作。如果为False
,则返回None
并在函数内部等待句柄。- 返回类型
注意
如果 `async_op=True`,则 FSDP 将在用户的前向传播之前等待待处理的 unshard 操作。只有当等待应在前向传播之前发生时,用户才需要显式调用 `
wait()
`。
- class torch.distributed.fsdp.UnshardHandle#
用于等待 `
FSDPModule.unshard()
操作的句柄。
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]#
在 `module` 上注册一个方法,以将其视为 FSDP 的前向方法。
FSDP 在前向传播之前 all-gather 参数,并在前向传播之后(取决于 `reshard_after_forward`)可选地释放参数。FSDP 默认只知道对 `nn.Module.forward()` 执行此操作。此函数会修补用户指定的方法,以便在前向/后向方法之前/之后分别运行预/后前向钩子。如果 `module` 不是 `
FSDPModule
`,则此函数无操作。
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)#
这配置了 FSDP 的混合精度。与 autocast 不同,这在模块级别而不是操作级别应用混合精度,这意味着低精度激活会为后向传播保存,而高到低精度的转换仅在模块边界发生。
FSDP 与模块级别的混合精度配合良好,因为它仍然将高精度分片参数保存在内存中。换句话说,FSDP 不需要额外的内存来保留参数的高精度副本以供优化器步骤使用。
- 变量
param_dtype (Optional[torch.dtype]) – 这指定了未分片参数的数据类型,因此是前向/后向计算和参数 all-gather 的数据类型。如果为 `None`,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型中的分片参数。(默认:
None
)reduce_dtype (Optional[torch.dtype]) – 这指定了梯度归约(即 reduce-scatter 或 all-reduce)的数据类型。如果为 `None` 但 `param_dtype` 不是 `None`,则归约使用计算数据类型。这可用于在使用低精度进行计算的同时,以全精度执行梯度归约。如果还通过 `
set_requires_gradient_sync()
` 禁用了梯度归约,则 FSDP 将使用 `reduce_dtype` 累积梯度。(默认:None
)output_dtype (Optional[torch.dtype]) – 这指定了用于转换浮点前向输出的数据类型。这可用于帮助实现不同模块具有不同混合精度策略的情况。(默认:
None
)cast_forward_inputs (bool) – 这指定了 FSDP 是否应将前向传播的浮点输入张量转换为 `param_dtype`。
- class torch.distributed.fsdp.OffloadPolicy#
这个基类表示无卸载策略,仅用作 `
offload_policy
` 参数的默认值。
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)#
此卸载策略会将参数、梯度和优化器状态卸载到 CPU。分片参数在 all-gather 之前复制到主机-设备(host-to-device)。all-gather 的参数根据 `reshard_after_forward` 进行释放。分片梯度在后向传播时从设备复制到主机(device-to-host),优化器步骤在 CPU 上使用 CPU 优化器状态运行。
- 变量
pin_memory (bool) – 是否固定分片参数和梯度内存。固定内存可以实现更高效的主机到设备/设备到主机(H2D/D2H)复制,并允许复制与计算重叠。但是,固定的内存不能被其他进程使用。如果 CPU 内存不足,请将其设置为
False
。(默认:True
)