FullyShardedDataParallel#
创建日期:2022年2月2日 | 最后更新日期:2025年6月11日
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[源码]#
用于在数据并行工作进程间对模块参数进行分片(Sharding)的包装器。
此实现灵感来源于 Xu 等人 以及 DeepSpeed 中的 ZeRO Stage 3。FullyShardedDataParallel 通常简称为 FSDP。
示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
使用 FSDP 时,需要先包装你的模块,然后再初始化优化器。这是必须的,因为 FSDP 会改变参数变量。
设置 FSDP 时,你需要考虑目标 CUDA 设备。如果设备具有 ID(
dev_id),你有三种选择:将模块放置在该设备上
使用
torch.cuda.set_device(dev_id)设置设备将
dev_id传递给构造函数参数device_id。
这能确保 FSDP 实例的计算设备是目标设备。对于选项 1 和 3,FSDP 初始化始终在 GPU 上发生。对于选项 2,FSDP 初始化发生在模块当前的设备上,这可能是 CPU。
如果你使用
sync_module_states=True标志,则需要确保模块位于 GPU 上,或者使用device_id参数指定一个 CUDA 设备,以便 FSDP 在构造函数中将模块移动过去。这是必要的,因为sync_module_states=True需要 GPU 通信。FSDP 还会负责将输入张量移动到前向传播方法的 GPU 计算设备上,因此你无需手动从 CPU 移动它们。
对于
use_orig_params=True,与ShardingStrategy.FULL_SHARD不同,ShardingStrategy.SHARD_GRAD_OP在前向传播后暴露的是非分片参数,而非分片参数。如果你想检查梯度,可以使用summon_full_params方法并设置with_grads=True。使用
limit_all_gathers=True时,你可能会在 FSDP 前向传播前看到 CPU 线程未发出任何核函数的间隙。这是有意为之的,表明速率限制器生效。通过这种方式同步 CPU 线程可以防止为后续的 all-gather 操作过度分配内存,且实际上不会延迟 GPU 核函数的执行。出于自动求导(autograd)相关的原因,FSDP 在前向和后向计算期间会将受管理的模块参数替换为
torch.Tensor视图。如果你的模块前向传播依赖于对参数的保存引用,而不是在每次迭代中重新获取引用,那么它将无法识别 FSDP 新创建的视图,从而导致自动求导无法正常工作。最后,当使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD,且分片进程组为节点内(intra-node)、复制进程组为节点间(inter-node)时,设置NCCL_CROSS_NIC=1可以帮助改进某些集群配置中复制进程组的 all-reduce 时间。限制
使用 FSDP 时需要注意以下几点限制:
在使用 CPU 卸载(CPU offloading)时,FSDP 目前不支持
no_sync()之外的梯度累积。这是因为 FSDP 使用新归约(reduced)的梯度而不是与现有梯度累积,这可能导致结果不正确。FSDP 不支持运行包含在 FSDP 实例中的子模块的前向传播。这是因为子模块的参数将被分片,但子模块本身不是一个 FSDP 实例,因此其前向传播无法适当地 all-gather 完整参数。
由于其注册后向钩子(backward hooks)的方式,FSDP 不支持双向后向传播(double backwards)。
FSDP 在冻结参数时有一些约束。对于
use_orig_params=False,每个 FSDP 实例必须管理要么全部冻结,要么全部非冻结的参数。对于use_orig_params=True,FSDP 支持混合冻结和非冻结参数,但建议避免这样做,以防止梯度内存使用量高于预期。截至 PyTorch 1.12,FSDP 对共享参数的支持有限。如果你的用例需要增强共享参数支持,请发布在 此 issue 中。
你应该避免在不使用
summon_full_params上下文的情况下在前向和后向传播之间修改参数,因为修改可能不会持久保留。
- 参数:
module (nn.Module) – 这是要使用 FSDP 包装的模块。
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型进行分片的进程组,因此也是用于 FSDP 的 all-gather 和 reduce-scatter 集合通信的进程组。如果为
None,则 FSDP 使用默认进程组。对于混合分片策略(如ShardingStrategy.HYBRID_SHARD),用户可以传入一个进程组元组,分别代表分片和复制的组。如果为None,FSDP 会为用户构造节点内分片和节点间复制的进程组。(默认值:None)sharding_strategy (Optional[ShardingStrategy]) – 配置分片策略,这可能会在内存节省和通信开销之间进行权衡。详见
ShardingStrategy。(默认值:FULL_SHARD)cpu_offload (Optional[CPUOffload]) – 配置 CPU 卸载。如果设置为
None,则不进行 CPU 卸载。详见CPUOffload。(默认值:None)auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –
这指定了将 FSDP 应用于
module子模块的策略,这对于通信和计算的重叠是必需的,从而影响性能。如果为None,则 FSDP 仅应用于module,用户应手动将 FSDP 应用于父模块(采用自底向上方式)。为方便起见,此项直接接受ModuleWrapPolicy,允许用户指定要包装的模块类(例如 transformer 块)。否则,这应该是一个可调用对象,接受三个参数module: nn.Module,recurse: bool和nonwrapped_numel: int,并应返回一个bool,指定如果recurse=False时是否应对传入的module应用 FSDP,或者如果recurse=True时是否应继续遍历到模块的子树中。用户可以向可调用对象添加额外参数。torch.distributed.fsdp.wrap.py中的size_based_auto_wrap_policy提供了一个示例,如果模块子树中的参数超过 100M 个元素,则对其应用 FSDP。建议在应用 FSDP 后打印模型并根据需要进行调整。示例
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]) – 配置显式的 all-gather 后向预取。如果为
None,则 FSDP 不进行后向预取,且在后向传播中没有通信和计算重叠。详见BackwardPrefetch。(默认值:BACKWARD_PRE)mixed_precision (Optional[MixedPrecision]) – 为 FSDP 配置原生混合精度。如果设置为
None,则不使用混合精度。否则,可以设置参数、缓冲区和梯度归约的数据类型。详见MixedPrecision。(默认值:None)ignored_modules (Optional[Iterable[torch.nn.Module]]) – 该实例将忽略其自身参数以及子模块参数和缓冲区的模块。
ignored_modules中直接包含的任何模块都不应该是FullyShardedDataParallel实例;如果嵌套在此实例下,已经是构造好的FullyShardedDataParallel实例的子模块将不会被忽略。此参数可用于在使用auto_wrap_policy时避免在模块粒度上对特定参数进行分片,或者如果参数的分片不由 FSDP 管理。(默认值:None)param_init_fn (Optional[Callable[[nn.Module], None]]) –
一个
Callable[torch.nn.Module] -> None,指定当前位于元设备(meta device)上的模块应如何初始化到实际设备上。截至 v1.12,FSDP 通过is_meta检测参数或缓冲区在元设备上的模块,如果指定了param_init_fn,则应用它,否则调用nn.Module.reset_parameters()。在这两种情况下,实现都应仅初始化模块的参数/缓冲区,而不初始化其子模块的参数/缓冲区,以避免重复初始化。此外,FSDP 还通过 torchdistX (pytorch/torchdistX) 的deferred_init()API 支持延迟初始化,延迟模块由param_init_fn(如果指定)或 torchdistX 的默认materialize_module()初始化。如果指定了param_init_fn,则将其应用于所有元设备模块,这意味着它可能需要针对不同模块类型进行区分。FSDP 在参数展平和分片之前调用初始化函数。示例
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – 一个
int或torch.device,指定 FSDP 初始化(包括必要时的模块初始化和参数分片)发生的 CUDA 设备。如果module位于 CPU 上,指定此参数可提高初始化速度。如果已设置默认 CUDA 设备(例如通过torch.cuda.set_device),则用户可以将torch.cuda.current_device传递给此项。(默认值:None)sync_module_states (bool) – 如果为
True,则每个 FSDP 模块将从 rank 0 广播模块参数和缓冲区,以确保它们在各 rank 间复制(此构造函数会增加通信开销)。这有助于以内存有效的方式通过load_state_dict加载state_dict检查点。示例请见FullStateDictConfig。(默认值:False)forward_prefetch (bool) – 如果为
True,则 FSDP 在当前前向计算之前显式预取下一个前向传播的 all-gather。这仅对 CPU 密集型工作负载有用,在这种情况下,更早发出下一个 all-gather 可能会改善重叠。由于预取遵循第一次迭代的执行顺序,因此这仅应用于静态图模型。(默认值:False)limit_all_gathers (bool) – 如果为
True,则 FSDP 显式同步 CPU 线程,以确保 GPU 内存使用量仅限于两个连续的 FSDP 实例(当前运行计算的实例和预取了 all-gather 的下一个实例)。如果为False,则 FSDP 允许 CPU 线程发出 all-gather 而无需任何额外同步。(默认值:True)我们将此功能称为“速率限制器”。仅对于内存压力较小的特定 CPU 密集型工作负载,才应将此标志设置为False,在这种情况下,CPU 线程可以在不考虑 GPU 内存使用的情况下积极发出所有核函数。use_orig_params (bool) – 将此设置为
True会让 FSDP 使用module的原始参数。FSDP 通过nn.Module.named_parameters()向用户暴露这些原始参数,而不是 FSDP 内部的FlatParameter。这意味着优化器步骤在原始参数上运行,从而支持针对每个原始参数的超参数。FSDP 保留原始参数变量,并在非分片形式和分片形式之间操纵它们的数据,其中它们始终分别是底层非分片或分片FlatParameter的视图。使用当前算法,分片形式始终是一维的,丢失了原始张量结构。对于给定的 rank,原始参数的数据可能全部、部分或全部缺失。如果缺失,其数据将类似于大小为 0 的空张量。用户不应编写依赖于其分片形式中给定原始参数存在哪些数据的程序。使用torch.compile()需要True。将此设置为False会通过nn.Module.named_parameters()向用户暴露 FSDP 的内部FlatParameter。(默认值:False)ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 将不由该 FSDP 实例管理的已忽略参数或模块,这意味着参数不会被分片,且其梯度也不会在各 rank 间进行归约。此参数与现有的
ignored_modules参数统一,我们可能很快会弃用ignored_modules。为了向后兼容,我们同时保留ignored_states和 ignored_modules,但 FSDP 只允许其中之一被指定为非None。device_mesh (Optional[DeviceMesh]) – DeviceMesh 可以用作 process_group 的替代方案。传递 device_mesh 时,FSDP 将使用底层进程组进行 all-gather 和 reduce-scatter 集合通信。因此,这两个参数需要互斥。对于混合分片策略(如
ShardingStrategy.HYBRID_SHARD),用户可以传入一个二维 DeviceMesh 而不是进程组元组。对于二维 FSDP + TP,用户需要传入 device_mesh 而不是 process_group。有关 DeviceMesh 的更多信息,请访问:https://pytorch.ac.cn/tutorials/recipes/distributed_device_mesh.html
- apply(fn)[源码]#
将
fn递归应用于每个子模块(由.children()返回)以及自身。典型用法包括初始化模型的参数(另请参阅 torch.nn.init)。
与
torch.nn.Module.apply相比,此版本在应用fn之前还会收集完整参数。不应在另一个summon_full_params上下文中调用它。- 参数:
fn (
Module-> None) – 要应用于每个子模块的函数- 返回:
self
- 返回类型:
- clip_grad_norm_(max_norm, norm_type=2.0)[源码]#
对所有参数的梯度范数进行裁剪。
范数是基于所有参数梯度视为单个向量来计算的,且梯度是原地(in-place)修改的。
- 参数:
- 返回:
参数的总范数(视为单个向量)。
- 返回类型:
如果每个 FSDP 实例都使用
NO_SHARD,意味着梯度不在各 rank 间分片,那么你可以直接使用torch.nn.utils.clip_grad_norm_()。如果至少有一个 FSDP 实例使用分片策略(即
NO_SHARD以外的策略),那么你应该使用此方法而不是torch.nn.utils.clip_grad_norm_(),因为此方法处理了梯度在各 rank 间分片的情况。返回的总范数将根据 PyTorch 的类型提升语义,采用所有参数/梯度中“最高”的数据类型。例如,如果所有参数/梯度都使用低精度数据类型,则返回范数的类型将是该低精度数据类型,但如果存在至少一个使用 FP32 的参数/梯度,则返回范数的类型将是 FP32。
警告
由于该方法使用了集合通信,因此需要在所有 rank 上调用它。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[源码]#
展平一个分片后的优化器状态字典(optimizer state-dict)。
该 API 类似于
shard_full_optim_state_dict()。唯一的区别是输入的sharded_optim_state_dict应从sharded_optim_state_dict()返回。因此,每个 rank 上都会有 all-gather 调用来收集ShardedTensor。- 参数:
sharded_optim_state_dict (Dict[str, Any]) – 对应于未展平参数并持有分片优化器状态的优化器状态字典。
model (torch.nn.Module) – 请参阅
shard_full_optim_state_dict()。optim (torch.optim.Optimizer) –
model参数的优化器。
- 返回:
- 返回类型:
- static fsdp_modules(module, root_only=False)[源码]#
返回所有嵌套的 FSDP 实例。
这可能包括
module本身;如果root_only=True,则仅包括 FSDP 根模块。- 参数:
- 返回:
嵌套在输入
module中的 FSDP 模块。- 返回类型:
List[FullyShardedDataParallel]
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]#
返回完整的优化器状态字典(optimizer state-dict)。
将完整的优化器状态汇总到 rank 0,并以
dict的形式返回,遵循torch.optim.Optimizer.state_dict()的约定,即包含键"state"和"param_groups"。包含在model中的FSDP模块内的扁平化参数会被映射回其原始的非扁平化参数。由于此方法使用了集合通信,因此需要在所有 rank 上调用。但是,如果
rank0_only=True,则状态字典仅在 rank 0 上填充,其他所有 rank 将返回一个空的dict。与
torch.optim.Optimizer.state_dict()不同,此方法使用完整的参数名称作为键,而不是参数 ID。与
torch.optim.Optimizer.state_dict()一样,优化器状态字典中包含的张量不会被克隆,因此可能会出现别名(aliasing)方面的问题。作为最佳实践,请考虑立即保存返回的优化器状态字典,例如使用torch.save()。- 参数:
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数已传递给优化器optim。optim (torch.optim.Optimizer) –
model参数的优化器。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器
optim的输入,表示参数组的list或参数的可迭代对象;如果为None,则此方法假定输入为model.parameters()。该参数已弃用,不再需要传入。(默认值:None)rank0_only (bool) – 如果为
True,则仅在 rank 0 上保存填充的dict;如果为False,则在所有 rank 上保存。(默认值:True)group (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组,则为
None。(默认值:None)
- 返回:
一个包含
model原始非扁平化参数的优化器状态的dict,并包含遵循torch.optim.Optimizer.state_dict()约定的“state”和“param_groups”键。如果rank0_only=True,则非零 rank 返回一个空的dict。- 返回类型:
Dict[str, Any]
- static get_state_dict_type(module)[source]#
获取以
module为根的 FSDP 模块的 state_dict_type 及相应的配置。目标模块不必是 FSDP 模块。
- 返回:
一个包含当前设置的 state_dict_type 以及 state_dict / optim_state_dict 配置的
StateDictSettings。- 抛出:
AssertionError`(如果不同的 StateDictSettings) –
FSDP 子模块不同。 –
- 返回类型:
- named_buffers(*args, **kwargs)[source]#
返回模块缓冲区的迭代器,生成缓冲区名称和缓冲区本身。
在
summon_full_params()上下文管理器内时,拦截缓冲区名称并删除所有 FSDP 特定的扁平化缓冲区前缀。
- named_parameters(*args, **kwargs)[source]#
返回模块参数的迭代器,生成参数名称和参数本身。
在
summon_full_params()上下文管理器内时,拦截参数名称并删除所有 FSDP 特定的扁平化参数前缀。
- no_sync()[source]#
禁用跨 FSDP 实例的梯度同步。
在此上下文中,梯度将累积在模块变量中,并将在退出上下文后的第一次前向-反向传递中进行同步。这应该仅在根 FSDP 实例上使用,并将递归应用于所有子 FSDP 实例。
注意
这可能会导致更高的内存使用量,因为 FSDP 将累积完整的模型梯度(而不是梯度分片),直到最终同步。
注意
当与 CPU 卸载一起使用时,在上下文管理器内时,梯度不会被卸载到 CPU。相反,它们仅在最终同步之后才会被卸载。
- 返回类型:
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]#
转换与分片模型相对应的优化器状态字典。
给定的状态字典可以转换为三种类型之一:1) 完整优化器状态字典,2) 分片优化器状态字典,3) 本地优化器状态字典。
对于完整优化器状态字典,所有状态均为非扁平化且未分片的。可以通过
state_dict_type()指定仅 rank0 和仅 CPU,以避免 OOM(内存不足)。对于分片优化器状态字典,所有状态均为非扁平化但已分片的。可以通过
state_dict_type()指定仅 CPU,以进一步节省内存。对于本地状态字典,不执行转换。但状态将从 nn.Tensor 转换为 ShardedTensor,以表示其分片性质(目前尚未支持)。
示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 参数:
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数已传递给优化器optim。optim (torch.optim.Optimizer) –
model参数的优化器。optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果值为 None,将使用 optim.state_dict()。(默认值:
None)group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组,则为
None。(默认值:None)
- 返回:
一个包含
model优化器状态的dict。优化器状态的分片基于state_dict_type。- 返回类型:
Dict[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]#
转换优化器状态字典,以便将其加载到与 FSDP 模型关联的优化器中。
给定一个通过
optim_state_dict()转换的optim_state_dict,它被转换为扁平化的优化器状态字典,该字典可以加载到optim中,而optim是model的优化器。model必须由 FullyShardedDataParallel 进行分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 参数:
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数已传递给优化器optim。optim (torch.optim.Optimizer) –
model参数的优化器。optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。
is_named_optimizer (bool) – 此优化器是否为 NamedOptimizer 或 KeyedOptimizer。仅当
optim是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时,才将其设置为 True。load_directly (bool) – 如果设置为 True,此 API 在返回结果之前也会调用 optim.load_state_dict(result)。否则,用户负责调用
optim.load_state_dict()。(默认值:False)group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组,则为
None。(默认值:None)
- 返回类型:
- register_comm_hook(state, hook)[source]#
注册一个通信钩子。
这是一个增强功能,为用户提供了灵活的钩子,他们可以指定 FSDP 如何在多个工作进程之间聚合梯度。此钩子可用于实现多种算法,例如 GossipGrad 和梯度压缩,这些算法在与
FullyShardedDataParallel训练时涉及不同的参数同步通信策略。警告
FSDP 通信钩子应该在运行初始前向传递之前注册,且只能注册一次。
- 参数:
state (object) –
传递给钩子以在训练过程中维护任何状态信息。示例包括梯度压缩中的误差反馈,GossipGrad 中接下来要通信的对等体等。它由每个工作进程本地存储,并由工作进程上的所有梯度张量共享。
hook (Callable) – 可调用对象,具有以下签名之一:1)
hook: Callable[torch.Tensor] -> None:此函数接收一个 Python 张量,它代表与该 FSDP 单元包装的模型(未被其他 FSDP 子单元包装的)对应的所有变量的完整、扁平化、未分片的梯度。然后执行所有必要的处理并返回None;2)hook: Callable[torch.Tensor, torch.Tensor] -> None:此函数接收两个 Python 张量,第一个代表与该 FSDP 单元包装的模型(未被其他 FSDP 子单元包装的)对应的所有变量的完整、扁平化、未分片的梯度。后者代表一个预设大小的张量,用于在归约(reduction)后存储分片梯度的一块。在两种情况下,可调用对象都会执行所有必要的处理并返回None。签名 1 的可调用对象应处理 NO_SHARD 情况下的梯度通信。签名 2 的可调用对象应处理分片情况下的梯度通信。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]#
将优化器状态字典
optim_state_dict的键重新映射为键类型optim_state_key_type。这可用于实现来自具有 FSDP 实例的模型和没有 FSDP 实例的模型的优化器状态字典之间的兼容性。
将 FSDP 完整优化器状态字典(即来自
full_optim_state_dict())重新键入以使用参数 ID 并可加载到未包装的模型中。>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
将来自未包装模型的正常优化器状态字典重新键入以使其可加载到已包装的模型中。
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- 返回:
使用
optim_state_key_type指定的参数键重新键入的优化器状态字典。- 返回类型:
Dict[str, Any]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]#
将完整的优化器状态字典从 rank 0 分散到所有其他 rank。
在每个 rank 上返回分片的优化器状态字典。返回值与
shard_full_optim_state_dict()相同,并且在 rank 0 上,第一个参数应该是full_optim_state_dict()的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
shard_full_optim_state_dict()和scatter_full_optim_state_dict()均可用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 都在 CPU 内存中拥有完整的字典,其中每个 rank 单独分片字典而无需任何通信,而后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,其中 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并将其适当地通信给各个 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。- 参数:
full_optim_state_dict (Optional[Dict[str, Any]]) – 对应于非扁平化参数的优化器状态字典,如果在 rank 0 上,则持有完整的未分片优化器状态;该参数在非零 rank 上被忽略。
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组的
list或参数的可迭代对象;如果为None,则此方法假定输入为model.parameters()。该参数已弃用,不再需要传入。(默认值:None)optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更推荐使用的参数。(默认值:None)group (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组,则为
None。(默认值:None)
- 返回:
完整的优化器状态字典现在被重新映射到扁平化参数,而不是非扁平化参数,并且仅限于仅包含此 rank 的优化器状态部分。
- 返回类型:
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
设置目标模块的所有后代 FSDP 模块的
state_dict_type。还接受(可选)模型和优化器的状态字典配置。目标模块不必是 FSDP 模块。如果目标模块是 FSDP 模块,其
state_dict_type也将被更改。注意
此 API 应该仅对顶级(根)模块调用。
注意
此 API 使用户能够透明地使用传统的
state_dictAPI 来获取模型检查点,以防根 FSDP 模块被另一个nn.Module包装。例如,以下操作将确保在所有非 FSDP 实例上调用state_dict,同时分派到 FSDP 的 sharded_state_dict 实现。示例
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- 参数:
module (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的目标
state_dict_type。state_dict_config (Optional[StateDictConfig]) – 目标
state_dict_type的配置。optim_state_dict_config (Optional[OptimStateDictConfig]) – 优化器状态字典的配置。
- 返回:
一个包含模块的前一个 state_dict 类型和配置的 StateDictSettings。
- 返回类型:
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]#
分片完整的优化器状态字典。
将
full_optim_state_dict中的状态重新映射为扁平化参数,而不是非扁平化参数,并限制为仅包含此 rank 的优化器状态部分。第一个参数应该是full_optim_state_dict()的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
shard_full_optim_state_dict()和scatter_full_optim_state_dict()均可用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 都在 CPU 内存中拥有完整的字典,其中每个 rank 单独分片字典而无需任何通信,而后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,其中 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并将其适当地通信给各个 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。- 参数:
full_optim_state_dict (Dict[str, Any]) – 对应于非扁平化参数并持有完整未分片优化器状态的优化器状态字典。
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组的
list或参数的可迭代对象;如果为None,则此方法假定输入为model.parameters()。该参数已弃用,不再需要传入。(默认值:None)optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更推荐使用的参数。(默认值:None)
- 返回:
完整的优化器状态字典现在被重新映射到扁平化参数,而不是非扁平化参数,并且仅限于仅包含此 rank 的优化器状态部分。
- 返回类型:
Dict[str, Any]
- static sharded_optim_state_dict(model, optim, group=None)[source]#
以分片(sharded)形式返回优化器状态字典(optimizer state-dict)。
此 API 与
full_optim_state_dict()类似,但它会将所有非零维度的状态分块为ShardedTensor以节省内存。此 API 仅应在通过上下文管理器with state_dict_type(SHARDED_STATE_DICT):导出模型state_dict时使用。关于详细用法,请参阅
full_optim_state_dict()。警告
返回的状态字典包含
ShardedTensor,不能直接被常规的optim.load_state_dict使用。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
设置目标模块的所有后代 FSDP 模块的
state_dict_type。此上下文管理器的功能与
set_state_dict_type()相同。有关详细信息,请阅读set_state_dict_type()的文档。示例
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- 参数:
module (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的目标
state_dict_type。state_dict_config (Optional[StateDictConfig]) – 目标
state_dict_type的模型state_dict配置。optim_state_dict_config (Optional[OptimStateDictConfig]) – 目标
state_dict_type的优化器state_dict配置。
- 返回类型:
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]#
使用此上下文管理器为 FSDP 实例暴露完整参数。
这在模型前向/反向传播*之后*非常有用,可以获取参数进行进一步处理或检查。它支持接收非 FSDP 模块,并会根据
recurse参数为所有包含的 FSDP 模块及其子模块召唤完整参数。注意
这也可用于嵌套的 FSDP。
注意
此上下文不能在正向或反向传播过程中使用。也不能在此上下文中启动正向或反向传播。
注意
在退出上下文管理器后,参数将恢复为本地分片状态,存储行为与前向传播相同。
注意
可以修改完整参数,但只有与本地参数分片对应的部分在上下文管理器退出后会持久化(除非设置
writeback=False,在这种情况下修改将被丢弃)。在 FSDP 不对参数进行分片的情况下(当前仅当world_size == 1或使用NO_SHARD配置时),无论writeback如何设置,修改都会持久化。注意
此方法适用于自身并非 FSDP 但可能包含多个独立 FSDP 单元的模块。在这种情况下,给定的参数将应用于所有包含的 FSDP 单元。
警告
请注意,目前不支持将
rank0_only=True与writeback=True结合使用,否则会引发错误。这是因为模型参数形状在上下文中各 rank 之间可能不同,且在退出上下文时写入它们可能会导致各 rank 间的不一致。警告
请注意,
offload_to_cpu和rank0_only=False会导致完整参数被冗余地复制到同一台机器上多个 GPU 的 CPU 内存中,这可能会引发 CPU 内存溢出 (OOM) 的风险。建议将offload_to_cpu与rank0_only=True配合使用。- 参数:
recurse (bool, Optional) – 是否递归地为嵌套的 FSDP 实例召唤所有参数(默认值:True)。
writeback (bool, Optional) – 如果为
False,则上下文管理器退出后对参数的修改将被丢弃;禁用此选项可以稍微提高效率(默认值:True)。rank0_only (bool, Optional) – 如果为
True,完整参数仅在全局 rank 0 上实例化。这意味着在上下文中,只有 rank 0 拥有完整参数,其他 rank 将拥有分片参数。请注意,不支持将rank0_only=True与writeback=True结合使用,因为模型参数形状在上下文中各 rank 之间可能不同,在退出上下文时写入它们可能导致各 rank 间的不一致。offload_to_cpu (bool, Optional) – 如果为
True,完整参数会被卸载到 CPU。请注意,这种卸载仅在参数被分片时才会发生(不分片的情况仅限于 world_size = 1 或NO_SHARD配置)。建议将offload_to_cpu与rank0_only=True配合使用,以避免将模型参数的冗余副本卸载到相同的 CPU 内存中。with_grads (bool, Optional) – 如果为
True,梯度也会随参数一起取消分片。目前,这仅在向 FSDP 构造函数传递use_orig_params=True且向此方法传递offload_to_cpu=False时才受支持。(默认值:False)
- 返回类型:
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]#
此配置用于显式反向预取(backward prefetching),通过在反向传播中启用通信与计算重叠来提高吞吐量,代价是内存使用量略有增加。
BACKWARD_PRE:这能启用最大程度的重叠,但内存使用量增加最多。它在当前参数组的梯度计算*之前*预取下一组参数。这实现了*下一次 all-gather* 与*当前梯度计算*的重叠,在峰值时,内存中会同时持有当前参数组、下一组参数和当前梯度组。BACKWARD_POST:这能启用较少的重叠,但所需的内存较少。它在当前参数组的梯度计算*之后*预取下一组参数。这实现了*当前 reduce-scatter* 与*下一次梯度计算*的重叠,并且会在分配下一组参数的内存之前释放当前参数组的内存,峰值时内存中仅持有下一组参数和当前梯度组。FSDP 的
backward_prefetch参数接受None,这会完全禁用反向预取。这种情况下没有重叠,也不会增加内存使用量。通常我们不建议使用此设置,因为它可能会显著降低吞吐量。
更多技术背景:对于使用 NCCL 后端的单个进程组,任何集合通信(collectives)——即使是从不同流发出的——都会竞争同一个设备内的 NCCL 流,这意味着集合通信发出的相对顺序对于重叠至关重要。这两个反向预取值对应不同的发出顺序。
- class torch.distributed.fsdp.ShardingStrategy(value)[source]#
这指定了由
FullyShardedDataParallel用于分布式训练的分片策略。FULL_SHARD:参数、梯度和优化器状态均被分片。对于参数,此策略在前向传播前取消分片(通过 all-gather),在前向传播后重新分片,在反向计算前取消分片,并在反向计算后重新分片。对于梯度,它在反向计算后同步并分片(通过 reduce-scatter)。分片后的优化器状态在每个 rank 上本地更新。SHARD_GRAD_OP:梯度和优化器状态在计算过程中被分片,此外参数在计算之外被分片。对于参数,此策略在前向传播前取消分片,在前向传播后不重新分片,仅在反向计算后重新分片。分片后的优化器状态在每个 rank 上本地更新。在no_sync()内部,参数在反向计算后不会重新分片。NO_SHARD:参数、梯度和优化器状态不分片,而是类似于 PyTorch 的DistributedDataParallelAPI 在各 rank 间进行复制。对于梯度,此策略在反向计算后同步它们(通过 all-reduce)。未分片的优化器状态在每个 rank 上本地更新。HYBRID_SHARD:在节点内应用FULL_SHARD,并在节点间复制参数。这减少了通信量,因为昂贵的 all-gather 和 reduce-scatter 仅在节点内进行,这对于中等规模的模型可能更具性能优势。_HYBRID_SHARD_ZERO2:在节点内应用SHARD_GRAD_OP,并在节点间复制参数。这类似于HYBRID_SHARD,但可能会提供更高的吞吐量,因为在前向传播后未分片的参数不会被释放,从而节省了后向传播前的 all-gather。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source]#
此配置用于配置 FSDP 原生的混合精度训练。
- 变量:
param_dtype (Optional[torch.dtype]) – 指定模型参数在前向和反向传播期间的 dtype,从而指定前向和反向计算的 dtype。在前向和反向传播之外,*分片*参数保持全精度(例如用于优化器步骤),且对于模型检查点,参数总是以全精度保存。(默认值:
None)reduce_dtype (Optional[torch.dtype]) – 指定梯度归约(即 reduce-scatter 或 all-reduce)的 dtype。如果此项为
None但param_dtype不为None,则此项将采用param_dtype的值,仍然以低精度运行梯度归约。此项被允许与param_dtype不同,例如强制以全精度运行梯度归约。(默认值:None)buffer_dtype (Optional[torch.dtype]) – 指定缓冲区(buffers)的 dtype。FSDP 不会对缓冲区进行分片。相反,FSDP 在第一次前向传播中将其转换为
buffer_dtype,并在之后保持该 dtype。对于模型检查点,缓冲区以全精度保存,除了LOCAL_STATE_DICT。(默认值:None)keep_low_precision_grads (bool) – 如果为
False,FSDP 会在反向传播之后将梯度上转(upcast)为全精度,以备优化器步骤使用。如果为True,FSDP 将梯度保留为用于梯度归约的 dtype,如果使用支持在低精度下运行的自定义优化器,这可以节省内存。(默认值:False)cast_forward_inputs (bool) – 如果为
True,该 FSDP 模块将其前向传播的 args 和 kwargs 转换为param_dtype。这是为了确保参数和输入的 dtypes 在前向计算时匹配,正如许多算子所要求的那样。当仅对部分而非全部 FSDP 模块应用混合精度时,可能需要将其设置为True,在这种情况下,混合精度 FSDP 子模块需要重新转换其输入。(默认值:False)cast_root_forward_inputs (bool) – 如果为
True,根 FSDP 模块将其前向传播的 args 和 kwargs 转换为param_dtype,并覆盖cast_forward_inputs的值。对于非根 FSDP 模块,此项不起作用。(默认值:True)_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – 指定使用
auto_wrap_policy时忽略混合精度的模块类:这些类的模块将单独应用 FSDP,并禁用混合精度(这意味着最终的 FSDP 构建会偏离指定的策略)。如果未指定auto_wrap_policy,则此项不起作用。此 API 为实验性,可能会发生变化。(默认值:(_BatchNorm,))
注意
此 API 为实验性,可能会发生变化。
注意
仅浮点张量会被转换为指定的 dtypes。
注意
在
summon_full_params中,参数被强制设为全精度,但缓冲区则不会。注意
Layer Norm 和 Batch Norm 即使在输入为如
float16或bfloat16等低精度时,也会在float32中进行累积。仅针对这些归一化模块禁用 FSDP 的混合精度,意味着仿射参数保持在float32。然而,这会导致这些归一化模块进行额外的 all-gather 和 reduce-scatter,这可能效率较低,因此如果工作负载允许,用户应优先选择对这些模块也应用混合精度。注意
默认情况下,如果用户传入的模型包含任何
_BatchNorm模块并指定了auto_wrap_policy,那么 Batch Norm 模块将单独应用 FSDP 并禁用混合精度。请参阅_module_classes_to_ignore参数。注意
MixedPrecision默认具有cast_root_forward_inputs=True和cast_forward_inputs=False。对于根 FSDP 实例,其cast_root_forward_inputs的优先级高于其cast_forward_inputs。对于非根 FSDP 实例,其cast_root_forward_inputs值将被忽略。此默认设置足以应对通常的情况,即每个 FSDP 实例具有相同的MixedPrecision配置,并且仅需要在模型前向传播开始时将输入转换为param_dtype。注意
对于具有不同
MixedPrecision配置的嵌套 FSDP 实例,我们建议设置各个cast_forward_inputs值,以配置是否在每个实例前向传播前进行转换。在这种情况下,由于转换发生在每个 FSDP 实例的前向传播之前,父 FSDP 实例应确保其非 FSDP 子模块先于其 FSDP 子模块运行,以避免激活值的 dtype 因不同的MixedPrecision配置而发生改变。示例
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
上面展示了一个有效示例。另一方面,如果
model[1]被替换为model[0],这意味着使用不同MixedPrecision的子模块先运行了前向传播,那么model[1]将错误地接收到float16激活值,而不是bfloat16激活值。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]#
此配置用于配置 CPU 卸载。
- 变量:
offload_params (bool) – 指定当参数未参与计算时是否将其卸载到 CPU。如果为
True,则同时会将梯度卸载到 CPU,这意味着优化器步骤将在 CPU 上运行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]#
StateDictConfig是所有state_dict配置类的基类。用户应实例化一个子类(例如FullStateDictConfig)来配置 FSDP 支持的相应state_dict类型的设置。- 变量:
offload_to_cpu (bool) – 如果为
True,FSDP 将 state dict 值卸载到 CPU;如果为False,FSDP 将其保留在 GPU 上。(默认值:False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]#
FullStateDictConfig是一个旨在与StateDictType.FULL_STATE_DICT一起使用的配置类。我们建议在保存完整 state dict 时同时启用offload_to_cpu=True和rank0_only=True,以分别节省 GPU 内存和 CPU 内存。此配置类旨在通过state_dict_type()上下文管理器使用,如下所示:>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP( ... model, ... device_id=torch.cuda.current_device(), ... auto_wrap_policy=..., ... sync_module_states=True, ... ) >>> # After this point, all ranks have FSDP model with loaded checkpoint.
- 变量:
rank0_only (bool) – 如果为
True,则仅 rank 0 保存完整 state dict,其他非零 rank 保存空字典。如果为False,则所有 rank 都保存完整 state dict。(默认值:False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[源码]#
ShardedStateDictConfig是一个配置类,旨在与StateDictType.SHARDED_STATE_DICT配合使用。- 变量:
_use_dtensor (bool) – 如果为
True,则 FSDP 将 state dict 值保存为DTensor;如果为False,则 FSDP 将其保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedStateDictConfig的私有字段,由 FSDP 用于确定 state dict 值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[源码]#
OptimStateDictConfig是所有optim_state_dict配置类的基类。用户应实例化一个子类(例如FullOptimStateDictConfig),以便为 FSDP 支持的相应optim_state_dict类型配置设置。- 变量:
offload_to_cpu (bool) – 如果为
True,则 FSDP 将 state dict 的张量值卸载(offload)到 CPU;如果为False,则 FSDP 将它们保留在原始设备上(除非启用了参数 CPU 卸载,否则为 GPU)。(默认值:True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[源码]#
- 变量:
rank0_only (bool) – 如果为
True,则仅 rank 0 保存完整 state dict,其他非零 rank 保存空字典。如果为False,则所有 rank 都保存完整 state dict。(默认值:False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[源码]#
ShardedOptimStateDictConfig是一个配置类,旨在与StateDictType.SHARDED_STATE_DICT配合使用。- 变量:
_use_dtensor (bool) – 如果为
True,则 FSDP 将 state dict 值保存为DTensor;如果为False,则 FSDP 将其保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedOptimStateDictConfig的私有字段,由 FSDP 用于确定 state dict 值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[源码]#