FullyShardedDataParallel#
创建日期:2022年2月2日 | 最后更新日期:2025年6月11日
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]#
一个用于在数据并行工作器之间分片模块参数的封装器。
这受到了 Xu 等人 的启发,以及 DeepSpeed 的 ZeRO 阶段 3。FullyShardedDataParallel 通常简称为 FSDP。
要了解 FSDP 内部机制,请参阅 FSDP 笔记。
示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
使用 FSDP 需要封装您的模块,然后初始化优化器。这是必需的,因为 FSDP 会更改参数变量。
设置 FSDP 时,您需要考虑目标 CUDA 设备。如果设备有 ID (
dev_id
),您有三种选择:将模块放置在该设备上。
使用
torch.cuda.set_device(dev_id)
设置设备。将
dev_id
传入device_id
构造函数参数。
这确保 FSDP 实例的计算设备是目标设备。对于选项 1 和 3,FSDP 初始化始终在 GPU 上进行。对于选项 2,FSDP 初始化发生在模块的当前设备上,这可能是 CPU。
如果您使用
sync_module_states=True
标志,您需要确保模块在 GPU 上,或者使用device_id
参数指定 FSDP 将在 FSDP 构造函数中将模块移动到的 CUDA 设备。这是必需的,因为sync_module_states=True
需要 GPU 通信。FSDP 还会将输入张量移动到 forward 方法的 GPU 计算设备,因此您无需手动将它们从 CPU 移动。
对于
use_orig_params=True
,ShardingStrategy.SHARD_GRAD_OP
公开未分片的参数,而不是 forward 之后的分片参数,这与ShardingStrategy.FULL_SHARD
不同。如果您想检查梯度,可以使用summon_full_params
方法并设置with_grads=True
。使用
limit_all_gathers=True
时,您可能会在 FSDP pre-forward 中看到 CPU 线程没有发出任何内核的间隙。这是有意为之,显示了速率限制器正在生效。以这种方式同步 CPU 线程可以防止为后续的 all-gather 操作过度分配内存,并且实际上不应延迟 GPU 内核的执行。出于 autograd 相关的原因,FSDP 在前向和后向计算期间用
torch.Tensor
视图替换了托管模块的参数。如果您的模块的前向依赖于保存的参数引用而不是每次迭代重新获取引用,那么它将无法看到 FSDP 新创建的视图,并且 autograd 将无法正常工作。最后,当使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD
且分片进程组为节点内,复制进程组为节点间时,在某些集群设置中,设置NCCL_CROSS_NIC=1
可以帮助改善复制进程组上的 all-reduce 时间。限制
使用 FSDP 时需要注意以下几个限制:
当使用 CPU 卸载时,FSDP 目前不支持在
no_sync()
之外的梯度累积。这是因为 FSDP 使用新减少的梯度而不是与任何现有梯度累积,这可能导致不正确的结果。FSDP 不支持运行包含在 FSDP 实例中的子模块的前向传递。这是因为子模块的参数将被分片,但子模块本身不是 FSDP 实例,因此其前向传递将无法正确地 all-gather 完整的参数。
由于其注册后向钩子的方式,FSDP 不支持双重后向。
FSDP 在冻结参数时有一些限制。对于
use_orig_params=False
,每个 FSDP 实例必须管理全部冻结或全部未冻结的参数。对于use_orig_params=True
,FSDP 支持混合冻结和未冻结的参数,但建议避免这样做,以防止出现高于预期的梯度内存使用量。截至 PyTorch 1.12,FSDP 对共享参数的支持有限。如果您的用例需要增强的共享参数支持,请在此 issue 中发布。
您应该避免在前向和后向之间修改参数而不使用
summon_full_params
上下文,因为修改可能不会持久。
- 参数
module (nn.Module) – 要用 FSDP 封装的模块。
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型分片所用的进程组,因此也是 FSDP 进行 all-gather 和 reduce-scatter 集合通信所用的进程组。如果为
None
,则 FSDP 使用默认进程组。对于混合分片策略,例如ShardingStrategy.HYBRID_SHARD
,用户可以传入一个进程组元组,分别表示用于分片和复制的组。如果为None
,则 FSDP 会为用户构造进程组,以便在节点内分片并在节点间复制。(默认值:None
)sharding_strategy (Optional[ShardingStrategy]) – 这配置了分片策略,可能会在内存节省和通信开销之间进行权衡。有关详细信息,请参阅
ShardingStrategy
。(默认值:FULL_SHARD
)cpu_offload (Optional[CPUOffload]) – 这配置了 CPU 卸载。如果设置为
None
,则不进行 CPU 卸载。有关详细信息,请参阅CPUOffload
。(默认值:None
)auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –
这指定了将 FSDP 应用于
module
子模块的策略,这对于通信和计算的重叠至关重要,从而影响性能。如果为None
,则 FSDP 仅应用于module
,用户应手动(自下而上)将 FSDP 应用于父模块。为方便起见,此参数直接接受ModuleWrapPolicy
,允许用户指定要封装的模块类(例如,transformer 块)。否则,此参数应为一个可调用对象,接受三个参数module: nn.Module
、recurse: bool
和nonwrapped_numel: int
,并返回一个bool
值,指示是否应将 FSDP 应用于传入的module
(如果recurse=False
),或遍历是否应继续进入模块的子树(如果recurse=True
)。用户可以向可调用对象添加额外参数。torch.distributed.fsdp.wrap.py
中的size_based_auto_wrap_policy
提供了一个示例可调用对象,它在子树中的参数数量超过 1 亿时将 FSDP 应用于模块。我们建议在应用 FSDP 后打印模型并根据需要进行调整。示例
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]) – 这配置了 all-gather 的显式反向预取。如果为
None
,则 FSDP 不进行反向预取,并且在反向传递中没有通信和计算重叠。有关详细信息,请参阅BackwardPrefetch
。(默认值:BACKWARD_PRE
)mixed_precision (Optional[MixedPrecision]) – 这配置了 FSDP 的原生混合精度。如果设置为
None
,则不使用混合精度。否则,可以设置参数、缓冲区和梯度约简的数据类型。有关详细信息,请参阅MixedPrecision
。(默认值:None
)ignored_modules (Optional[Iterable[torch.nn.Module]]) – 该实例忽略其自身参数和子模块参数及缓冲区的模块。在
ignored_modules
中的任何模块都不应该是FullyShardedDataParallel
实例,并且任何已经构建的FullyShardedDataParallel
实例(如果嵌套在此实例下)将不会被忽略。此参数可用于在使用auto_wrap_policy
或参数分片不由 FSDP 管理时,避免在模块粒度上分片特定参数。(默认值:None
)param_init_fn (可选[可调用[[nn.Module], 无]]) –
一个
Callable[torch.nn.Module] -> None
,它指定了目前在 meta 设备上的模块应如何初始化到实际设备上。自 v1.12 起,FSDP 通过is_meta
检测参数或缓冲区在 meta 设备上的模块,如果指定了param_init_fn
则应用它,否则调用nn.Module.reset_parameters()
。对于这两种情况,实现都应该只初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重复初始化。此外,FSDP 还支持通过 torchdistX (pytorch/torchdistX) 的deferred_init()
API 进行延迟初始化,其中延迟模块在指定param_init_fn
时通过调用它进行初始化,否则使用 torchdistX 的默认materialize_module()
。如果指定了param_init_fn
,则它将应用于所有元设备模块,这意味着它可能需要根据模块类型进行处理。FSDP 在参数展平(flattening)和分片(sharding)之前调用初始化函数。示例
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (可选[联合[int, torch.device]]) – 一个
int
或torch.device
,给出 FSDP 初始化发生时使用的 CUDA 设备,包括模块初始化(如果需要)和参数分片。如果module
在 CPU 上,应指定此参数以提高初始化速度。如果已设置默认 CUDA 设备(例如通过torch.cuda.set_device
),则用户可以传入torch.cuda.current_device
。(默认值:None
)sync_module_states (bool) – 如果为
True
,则每个 FSDP 模块会将模块参数和缓冲区从 rank 0 广播出去,以确保它们在所有 rank 上都已复制(这会增加此构造函数的通信开销)。这有助于以内存高效的方式通过load_state_dict
加载state_dict
检查点。请参阅FullStateDictConfig
了解示例。(默认值:False
)forward_prefetch (bool) – 如果为
True
,则 FSDP 会在当前前向计算之前显式地预取下一次前向传递的 all-gather。这仅对 CPU 密集型工作负载有用,在这种情况下,更早地发出下一次 all-gather 可能会改善重叠。这应该只用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认值:False
)limit_all_gathers (bool) – 如果为
True
,FSDP 会显式地同步 CPU 线程,以确保 GPU 内存仅由两个连续的 FSDP 实例使用(当前正在进行计算的实例和其 all-gather 被预取的下一个实例)。如果为False
,FSDP 允许 CPU 线程在没有任何额外同步的情况下发出 all-gather。(默认值:True
) 我们通常将此功能称为“速率限制器”。此标志只应在 CPU 密集型且内存压力较低的特定工作负载中设置为False
,在这种情况下,CPU 线程可以积极地发出所有内核,而无需担心 GPU 内存使用情况。use_orig_params (bool) – 将此设置为
True
会使 FSDP 使用module
的原始参数。FSDP 通过nn.Module.named_parameters()
而不是 FSDP 的内部FlatParameter
将这些原始参数暴露给用户。这意味着优化器步骤在原始参数上运行,从而实现按原始参数的超参数。FSDP 保留原始参数变量,并在未分片和分片形式之间操作其数据,其中它们始终是底层未分片或分片FlatParameter
的视图。使用当前算法,分片形式始终是 1D 的,丢失了原始张量结构。原始参数可能具有全部、部分或不具有其数据,具体取决于给定 rank。在不具有数据的情况下,其数据将类似于大小为 0 的空张量。用户不应编写依赖于给定原始参数在其分片形式中存在哪些数据的程序。True
是使用torch.compile()
所必需的。将此设置为False
会通过nn.Module.named_parameters()
将 FSDP 的内部FlatParameter
暴露给用户。(默认值:False
)ignored_states (可选[可迭代[torch.nn.Parameter]], 可选[可迭代[torch.nn.Module]]) – 将不由此 FSDP 实例管理的被忽略的参数或模块,这意味着参数不会被分片,并且其梯度不会在 rank 之间进行归约。此参数与现有的
ignored_modules
参数统一,我们可能很快就会弃用ignored_modules
。为了向后兼容,我们保留了ignored_states
和 `ignored_modules`,但 FSDP 只允许其中一个被指定为非None
。device_mesh (可选[DeviceMesh]) – DeviceMesh 可以用作 process_group 的替代品。当传入 device_mesh 时,FSDP 将使用底层的 process group 进行 all-gather 和 reduce-scatter 集合通信。因此,这两个参数需要互斥。对于混合分片策略,如
ShardingStrategy.HYBRID_SHARD
,用户可以传入 2D DeviceMesh 而不是 process group 的元组。对于 2D FSDP + TP,用户需要传入 device_mesh 而不是 process_group。有关 DeviceMesh 的更多信息,请访问:https://pytorch.ac.cn/tutorials/recipes/distributed_device_mesh.html
- apply(fn)[source]#
将
fn
递归应用于每个子模块(由.children()
返回)以及自身。典型用途包括初始化模型的参数(另请参阅 torch.nn.init)。
与
torch.nn.Module.apply
相比,此版本在应用fn
之前会额外收集完整的参数。它不应从另一个summon_full_params
上下文内部调用。- 参数
fn (
Module
-> None) – 要应用于每个子模块的函数- 返回
自身
- 返回类型
- clip_grad_norm_(max_norm, norm_type=2.0)[source]#
裁剪所有参数的梯度范数。
范数是在所有参数的梯度视为单个向量时计算的,并且梯度是就地修改的。
- 参数
- 返回
参数的总范数(视为单个向量)。
- 返回类型
如果每个 FSDP 实例都使用
NO_SHARD
,这意味着没有梯度在 rank 之间分片,则可以直接使用torch.nn.utils.clip_grad_norm_()
。如果至少有一些 FSDP 实例使用分片策略(即,不是
NO_SHARD
的策略),则应使用此方法而不是torch.nn.utils.clip_grad_norm_()
,因为此方法处理了梯度在 rank 之间分片的事实。返回的总范数将具有 PyTorch 类型提升语义定义的,所有参数/梯度中“最大”的数据类型。例如,如果所有参数/梯度都使用低精度数据类型,则返回的范数的数据类型将是该低精度数据类型,但如果存在至少一个使用 FP32 的参数/梯度,则返回的范数的数据类型将是 FP32。
警告
由于它使用集体通信,因此需要在所有 rank 上调用此函数。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]#
展平分片优化器状态字典。
API 类似于
shard_full_optim_state_dict()
。唯一的区别是输入sharded_optim_state_dict
应该从sharded_optim_state_dict()
返回。因此,每个 rank 上都会有 all-gather 调用来收集ShardedTensor
。- 参数
sharded_optim_state_dict (字典[str, 任意]) – 与未展平参数对应的优化器状态字典,并包含分片的优化器状态。
model (torch.nn.Module) – 请参阅
shard_full_optim_state_dict()
。optim (torch.optim.Optimizer) –
model
参数的优化器。
- 返回
- 返回类型
- static fsdp_modules(module, root_only=False)[source]#
返回所有嵌套的 FSDP 实例。
这可能包括
module
本身,并且只有当root_only=True
时才包括 FSDP 根模块。- 参数
module (torch.nn.Module) – 根模块,它可能是也可能不是
FSDP
模块。root_only (bool) – 是否只返回 FSDP 根模块。(默认值:
False
)
- 返回
嵌套在输入
module
中的 FSDP 模块。- 返回类型
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]#
返回完整的优化器状态字典。
将完整优化器状态合并到 rank 0,并以
dict
形式返回,遵循torch.optim.Optimizer.state_dict()
的约定,即包含键"state"
和"param_groups"
。包含在model
中的FSDP
模块中展平的参数被映射回其未展平的参数。由于它使用集体通信,因此需要在所有 rank 上调用此函数。但是,如果
rank0_only=True
,则状态字典仅在 rank 0 上填充,所有其他 rank 返回一个空的dict
。与
torch.optim.Optimizer.state_dict()
不同,此方法使用完整的参数名作为键,而不是参数 ID。与
torch.optim.Optimizer.state_dict()
类似,优化器状态字典中包含的张量未被克隆,因此可能存在别名(aliasing)问题。为获得最佳实践,请考虑立即保存返回的优化器状态字典,例如使用torch.save()
。- 参数
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel
实例),其参数已传递给优化器optim
。optim (torch.optim.Optimizer) –
model
参数的优化器。optim_input (可选[联合[列表[字典[str, 任意]], 可迭代[torch.nn.Parameter]]]) – 传入优化器
optim
的输入,表示参数组的list
或参数的迭代器;如果为None
,则此方法假定输入为model.parameters()
。此参数已弃用,不再需要传入。(默认值:None
)rank0_only (bool) – 如果为
True
,则仅在 rank 0 上保存填充的dict
;如果为False
,则在所有 rank 上保存。(默认值:True
)group (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组,则为
None
。(默认值:None
)
- 返回
一个
dict
,包含model
原始未展平参数的优化器状态,并包含键“state”和“param_groups”,遵循torch.optim.Optimizer.state_dict()
的约定。如果rank0_only=True
,则非零 rank 返回一个空的dict
。- 返回类型
字典[str, Any]
- static get_state_dict_type(module)[source]#
获取以
module
为根的 FSDP 模块的 state_dict_type 和相应的配置。目标模块不必是 FSDP 模块。
- 返回
一个
StateDictSettings
,包含当前设置的 state_dict_type 以及 state_dict/optim_state_dict 配置。- 引发
AssertionError` 如果不同 –
FSDP 子模块的 StateDictSettings 不同。 –
- 返回类型
- named_buffers(*args, **kwargs)[source]#
返回一个迭代器,遍历模块缓冲区,生成缓冲区名称和缓冲区本身。
当在
summon_full_params()
上下文管理器内部时,截取缓冲区名称并删除所有出现的 FSDP 特定的展平缓冲区前缀。- 返回类型
迭代器[元组[str, torch.Tensor]]
- named_parameters(*args, **kwargs)[source]#
返回一个迭代器,遍历模块参数,生成参数名称和参数本身。
当在
summon_full_params()
上下文管理器内部时,截取参数名称并删除所有出现的 FSDP 特定的展平参数前缀。- 返回类型
- no_sync()[source]#
禁用 FSDP 实例之间的梯度同步。
在此上下文内,梯度将累积在模块变量中,稍后将在退出上下文后的第一次前向-后向传递中同步。这仅应在根 FSDP 实例上使用,并将递归应用于所有子 FSDP 实例。
注意
这可能会导致更高的内存使用,因为 FSDP 将累积完整的模型梯度(而不是梯度分片),直到最终同步。
注意
与 CPU 卸载一起使用时,在上下文管理器内部时,梯度将不会卸载到 CPU。相反,它们只会在最终同步后立即卸载。
- 返回类型
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]#
转换与分片模型对应的优化器状态字典。
给定的状态字典可以转换为三种类型之一:1) 完整优化器状态字典,2) 分片优化器状态字典,3) 局部优化器状态字典。
对于完整的优化器状态字典,所有状态都未展平且未分片。可以通过
state_dict_type()
指定仅限 Rank0 和仅限 CPU,以避免内存不足 (OOM)。对于分片优化器状态字典,所有状态都已展平但已分片。可以通过
state_dict_type()
指定仅限 CPU,以进一步节省内存。对于局部状态字典,不会执行任何转换。但是,状态将从 nn.Tensor 转换为 ShardedTensor,以表示其分片特性(目前尚不支持)。
示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 参数
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel
实例),其参数已传递给优化器optim
。optim (torch.optim.Optimizer) –
model
参数的优化器。optim_state_dict (字典[str, 任意]) – 目标优化器状态字典,用于转换。如果值为 None,将使用 optim.state_dict()。(默认值:
None
)group (dist.ProcessGroup) – 模型参数分片的进程组,如果使用默认进程组,则为
None
。(默认值:None
)
- 返回
一个
dict
,包含model
的优化器状态。优化器状态的分片基于state_dict_type
。- 返回类型
字典[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]#
转换优化器状态字典,使其可以加载到与 FSDP 模型关联的优化器中。
给定一个通过
optim_state_dict()
转换的optim_state_dict
,它将被转换为展平的优化器状态字典,可以加载到optim
中,而optim
是model
的优化器。model
必须由 FullyShardedDataParallel 进行分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 参数
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel
实例),其参数已传递给优化器optim
。optim (torch.optim.Optimizer) –
model
参数的优化器。optim_state_dict (字典[str, 任意]) – 要加载的优化器状态。
is_named_optimizer (bool) – 此优化器是否为 NamedOptimizer 或 KeyedOptimizer。仅当
optim
是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时设置为 True。load_directly (bool) – 如果此参数设置为 True,则此 API 将在返回结果之前调用 optim.load_state_dict(result)。否则,用户需要自行调用
optim.load_state_dict()
。(默认值:False
)group (dist.ProcessGroup) – 模型参数分片的进程组,如果使用默认进程组,则为
None
。(默认值:None
)
- 返回类型
- register_comm_hook(state, hook)[source]#
注册通信钩子。
这是一项增强功能,为用户提供了灵活的钩子,他们可以指定 FSDP 如何在多个工作进程之间聚合梯度。此钩子可用于实现多种算法,如 GossipGrad 和梯度压缩,这些算法涉及在使用
FullyShardedDataParallel
训练时参数同步的不同通信策略。警告
FSDP 通信钩子应在运行初始前向传递之前注册,且只能注册一次。
- 参数
state (对象) –
传递给钩子以在训练过程中维护任何状态信息。示例包括梯度压缩中的错误反馈、GossipGrad 中与下一个通信的对等体等。它由每个工作者本地存储,并由工作者上的所有梯度张量共享。
hook (可调用) – 可调用对象,具有以下签名之一:1)
hook: Callable[torch.Tensor] -> None
:此函数接收一个 Python 张量,表示与此 FSDP 单元封装的模型(未被其他 FSDP 子单元封装)对应的所有变量的完整、展平、未分片梯度。然后执行所有必要的处理并返回None
;2)hook: Callable[torch.Tensor, torch.Tensor] -> None
:此函数接收两个 Python 张量,第一个表示与此 FSDP 单元封装的模型(未被其他 FSDP 子单元封装)对应的所有变量的完整、展平、未分片梯度。第二个表示一个预先确定大小的张量,用于在归约后存储分片梯度的一个块。在这两种情况下,可调用对象执行所有必要的处理并返回None
。签名 1 的可调用对象预计处理 NO_SHARD 情况下的梯度通信。签名 2 的可调用对象预计处理分片情况下的梯度通信。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]#
将优化器状态字典
optim_state_dict
的键重新映射为optim_state_key_type
指定的键类型。这可用于实现带有 FSDP 实例的模型与没有 FSDP 实例的模型之间的优化器状态字典的兼容性。
将 FSDP 完整优化器状态字典(即来自
full_optim_state_dict()
)的键重新映射为使用参数 ID,并使其可加载到未封装的模型中。>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
将来自未封装模型的正常优化器状态字典的键重新映射,使其可加载到已封装的模型中。
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- 返回
使用
optim_state_key_type
指定的参数键重新键入的优化器状态字典。- 返回类型
字典[str, Any]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]#
将完整的优化器状态字典从 rank 0 分散到所有其他 rank。
返回每个 rank 上的分片优化器状态字典。返回值与
shard_full_optim_state_dict()
相同,并且在 rank 0 上,第一个参数应该是full_optim_state_dict()
的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
shard_full_optim_state_dict()
和scatter_full_optim_state_dict()
都可以用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典驻留在 CPU 内存中,前者要求每个 rank 在 CPU 内存中拥有完整的字典,其中每个 rank 单独分片字典而无需任何通信,而后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,其中 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并将其适当地通信到相应的 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。- 参数
full_optim_state_dict (可选[字典[str, 任意]]) – 如果在 rank 0 上,则为与未展平参数对应的优化器状态字典,并包含完整的非分片优化器状态;在非零 rank 上,该参数将被忽略。
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel
实例),其参数与full_optim_state_dict
中的优化器状态相对应。optim_input (可选[联合[列表[字典[str, 任意]], 可迭代[torch.nn.Parameter]]]) – 传入优化器的输入,表示参数组的
list
或参数的迭代器;如果为None
,则此方法假定输入为model.parameters()
。此参数已弃用,不再需要传入。(默认值:None
)optim (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是优于
optim_input
的首选参数。(默认值:None
)group (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组,则为
None
。(默认值:None
)
- 返回
完整的优化器状态字典现在已重新映射到展平参数而不是未展平参数,并且仅限于此 rank 的优化器状态部分。
- 返回类型
字典[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
设置目标模块所有后代 FSDP 模块的
state_dict_type
。还可以选择模型和优化器状态字典的(可选)配置。目标模块不必是 FSDP 模块。如果目标模块是 FSDP 模块,其
state_dict_type
也将被更改。注意
此 API 只应为顶级(根)模块调用。
注意
此 API 允许用户在根 FSDP 模块被另一个
nn.Module
封装的情况下,透明地使用传统的state_dict
API 进行模型检查点。例如,以下代码将确保state_dict
在所有非 FSDP 实例上调用,同时为 FSDP 调度到 sharded_state_dict 实现。示例
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- 参数
module (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的期望
state_dict_type
。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type
的配置。optim_state_dict_config (可选[OptimStateDictConfig]) – 优化器状态字典的配置。
- 返回
一个 StateDictSettings,包括模块之前的 state_dict 类型和配置。
- 返回类型
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]#
分片完整的优化器状态字典。
将
full_optim_state_dict
中的状态重新映射为展平参数而不是未展平参数,并仅限于此 rank 的优化器状态部分。第一个参数应为full_optim_state_dict()
的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
shard_full_optim_state_dict()
和scatter_full_optim_state_dict()
都可以用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典驻留在 CPU 内存中,前者要求每个 rank 在 CPU 内存中拥有完整的字典,其中每个 rank 单独分片字典而无需任何通信,而后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,其中 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并将其适当地通信到相应的 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。- 参数
full_optim_state_dict (字典[str, 任意]) – 优化器状态字典,对应于未展平的参数,并包含完整的非分片优化器状态。
model (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel
实例),其参数与full_optim_state_dict
中的优化器状态相对应。optim_input (可选[联合[列表[字典[str, 任意]], 可迭代[torch.nn.Parameter]]]) – 传入优化器的输入,表示参数组的
list
或参数的迭代器;如果为None
,则此方法假定输入为model.parameters()
。此参数已弃用,不再需要传入。(默认值:None
)optim (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是优于
optim_input
的首选参数。(默认值:None
)
- 返回
完整的优化器状态字典现在已重新映射到展平参数而不是未展平参数,并且仅限于此 rank 的优化器状态部分。
- 返回类型
字典[str, Any]
- static sharded_optim_state_dict(model, optim, group=None)[source]#
以分片形式返回优化器状态字典。
此 API 类似于
full_optim_state_dict()
,但此 API 将所有非零维度状态分块为ShardedTensor
以节省内存。此 API 仅应在模型state_dict
通过上下文管理器with state_dict_type(SHARDED_STATE_DICT):
派生时使用。有关详细用法,请参阅
full_optim_state_dict()
。警告
返回的状态字典包含
ShardedTensor
,不能直接被常规optim.load_state_dict
使用。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
设置目标模块所有后代 FSDP 模块的
state_dict_type
。此上下文管理器与
set_state_dict_type()
具有相同的功能。请阅读set_state_dict_type()
的文档以获取详细信息。示例
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- 参数
module (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的期望
state_dict_type
。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type
的模型state_dict
配置。optim_state_dict_config (可选[OptimStateDictConfig]) – 目标
state_dict_type
的优化器state_dict
配置。
- 返回类型
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]#
使用此上下文管理器暴露 FSDP 实例的完整参数。
在模型的前向/后向之后,可以用于获取参数以进行额外处理或检查。它可以接受非 FSDP 模块,并将为所有包含的 FSDP 模块及其子模块召唤完整参数,具体取决于
recurse
参数。注意
这可以用于内部 FSDP。
注意
这不能在前向或后向传递中使用。也不能在此上下文内部启动前向和后向。
注意
参数在上下文管理器退出后将恢复为其局部碎片,存储行为与前向相同。
注意
完整的参数可以被修改,但只有与局部参数分片对应的部分会在上下文管理器退出后保留(除非
writeback=False
,在这种情况下更改将被丢弃)。在 FSDP 不分片参数的情况下(目前仅当world_size == 1
或NO_SHARD
配置时),无论writeback
如何,修改都会保留。注意
此方法适用于本身不是 FSDP 但可能包含多个独立 FSDP 单元的模块。在这种情况下,给定的参数将应用于所有包含的 FSDP 单元。
警告
请注意,
rank0_only=True
与writeback=True
结合目前不支持,并将引发错误。这是因为在上下文内部,模型参数形状在 rank 之间会有所不同,并且写入它们可能导致上下文退出时 rank 之间不一致。警告
请注意,
offload_to_cpu
和rank0_only=False
将导致完整参数冗余地复制到同一机器上的 CPU 内存中,这可能会带来 CPU OOM 的风险。建议将offload_to_cpu
与rank0_only=True
结合使用。- 参数
recurse (bool, 可选) – 递归地召唤嵌套 FSDP 实例的所有参数(默认:True)。
writeback (bool, 可选) – 如果为
False
,则在上下文管理器退出后,对参数的修改将被丢弃;禁用此功能可以稍微提高效率(默认:True)rank0_only (bool, 可选) – 如果为
True
,则完整参数仅在全球 rank 0 上具现化。这意味着在上下文内部,只有 rank 0 拥有完整参数,其他 rank 拥有分片参数。请注意,将rank0_only=True
与writeback=True
结合使用不受支持,因为在上下文内部,模型参数形状在 rank 之间会有所不同,写入它们可能导致上下文退出时 rank 之间不一致。offload_to_cpu (bool, 可选) – 如果为
True
,则完整参数将卸载到 CPU。请注意,目前此卸载仅在参数分片时发生(仅在 world_size = 1 或NO_SHARD
配置时才不是这种情况)。建议将offload_to_cpu
与rank0_only=True
结合使用,以避免模型参数的冗余副本卸载到同一 CPU 内存。with_grads (bool, 可选) – 如果为
True
,则梯度也会与参数一起取消分片。目前,这仅在向 FSDP 构造函数传递use_orig_params=True
且向此方法传递offload_to_cpu=False
时受支持。(默认值:False
)
- 返回类型
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]#
这配置了显式后向预取,通过在后向传递中实现通信和计算重叠来提高吞吐量,但代价是内存使用略有增加。
BACKWARD_PRE
:这实现了最大的重叠,但内存使用量增加最多。它在当前参数梯度计算之前预取下一组参数。这重叠了下一次 all-gather 和当前梯度计算,并且在峰值时,它将当前参数集、下一组参数和当前梯度集保存在内存中。BACKWARD_POST
:这实现了较少的重叠,但需要的内存使用量较少。它在当前参数梯度计算之后预取下一组参数。这重叠了当前 reduce-scatter 和下一次梯度计算,并且它在为下一组参数分配内存之前释放当前参数集,在峰值时仅将下一组参数和当前梯度集保存在内存中。FSDP 的
backward_prefetch
参数接受None
,这将完全禁用后向预取。这没有重叠,也不会增加内存使用量。一般来说,我们不推荐此设置,因为它可能会显著降低吞吐量。
更多技术背景:对于使用 NCCL 后端的单个进程组,任何集体操作,即使是从不同的流发出的,也会争用相同的每设备 NCCL 流,这意味着集体操作发出的相对顺序对于重叠很重要。两个后向预取值对应不同的发出顺序。
- class torch.distributed.fsdp.ShardingStrategy(value)[source]#
这指定了
FullyShardedDataParallel
用于分布式训练的分片策略。FULL_SHARD
:参数、梯度和优化器状态都已分片。对于参数,此策略在前向之前进行非分片(通过 all-gather),在前向之后进行重新分片,在后向计算之前进行非分片,在后向计算之后进行重新分片。对于梯度,它在后向计算之后同步并分片它们(通过 reduce-scatter)。分片的优化器状态在每个 rank 上本地更新。SHARD_GRAD_OP
:在计算过程中,梯度和优化器状态会被分片,此外,参数在计算之外也会被分片。对于参数,此策略在前向之前进行解分片,在前向之后不重新分片它们,而只在反向计算之后重新分片它们。分片的优化器状态在每个 rank 上本地更新。在no_sync()
内部,参数在反向计算后不会重新分片。NO_SHARD
:参数、梯度和优化器状态未分片,而是类似于 PyTorch 的DistributedDataParallel
API,在 rank 之间复制。对于梯度,此策略在后向计算之后同步它们(通过 all-reduce)。未分片的优化器状态在每个 rank 上本地更新。HYBRID_SHARD
:在节点内部应用FULL_SHARD
,并在节点之间复制参数。这会减少通信量,因为昂贵的 all-gather 和 reduce-scatter 只在节点内部完成,这对于中型模型来说可能更高效。_HYBRID_SHARD_ZERO2
:在节点内部应用SHARD_GRAD_OP
,并在节点之间复制参数。这类似于HYBRID_SHARD
,只是这可能会提供更高的吞吐量,因为未分片的参数在前向传递后不会被释放,从而节省了后向之前的 all-gather。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source]#
这配置了 FSDP 原生混合精度训练。
- 变量
param_dtype (可选[torch.dtype]) – 这指定了模型参数在前向和后向期间的数据类型,从而指定了前向和后向计算的数据类型。在前向和后向之外,分片参数保持全精度(例如用于优化器步骤),并且对于模型检查点,参数始终以全精度保存。(默认值:
None
)reduce_dtype (可选[torch.dtype]) – 这指定了梯度归约(即 reduce-scatter 或 all-reduce)的数据类型。如果此参数为
None
但param_dtype
不为None
,则此参数将采用param_dtype
的值,梯度归约仍以低精度运行。允许其与param_dtype
不同,例如强制梯度归约以全精度运行。(默认值:None
)buffer_dtype (可选[torch.dtype]) – 这指定了缓冲区的 dtype。FSDP 不分片缓冲区。相反,FSDP 在第一次前向传递中将它们转换为
buffer_dtype
,并此后保持该 dtype。对于模型检查点,缓冲区以全精度保存,除了LOCAL_STATE_DICT
。(默认值:None
)keep_low_precision_grads (bool) – 如果为
False
,则 FSDP 在反向传播后将梯度向上转换为全精度,为优化器步骤做准备。如果为True
,则 FSDP 将梯度保持在用于梯度归约的 dtype 中,如果使用支持低精度运行的自定义优化器,则可以节省内存。(默认值:False
)cast_forward_inputs (bool) – 如果为
True
,则此 FSDP 模块将其前向参数和关键字参数转换为param_dtype
。这是为了确保参数和输入 dtype 匹配前向计算(许多操作都需要)。当仅对部分 FSDP 模块应用混合精度时,这可能需要设置为True
,在这种情况下,混合精度 FSDP 子模块需要重新转换其输入。(默认值:False
)cast_root_forward_inputs (bool) – 如果为
True
,则根 FSDP 模块将其前向参数和关键字参数转换为param_dtype
,覆盖cast_forward_inputs
的值。对于非根 FSDP 模块,这不起作用。(默认值:True
)_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 当使用
auto_wrap_policy
时,此参数指定在混合精度模式下要忽略的模块类:这些类的模块将单独应用 FSDP,并禁用混合精度(这意味着最终的 FSDP 构造将偏离指定的策略)。如果未指定auto_wrap_policy
,则此参数不执行任何操作。此 API 处于实验阶段,可能会发生变化。(默认值:(_BatchNorm,)
)
注意
此 API 处于实验阶段,可能会发生变化。
注意
只有浮点张量会被转换为其指定的 dtype。
注意
在
summon_full_params
中,参数被强制为全精度,但缓冲区不是。注意
即使层归一化和批归一化层的输入精度较低,例如
float16
或bfloat16
,它们也会以float32
精度累积。为这些归一化模块禁用 FSDP 的混合精度,仅意味着仿射参数保持float32
精度。然而,这会导致这些归一化模块单独进行 all-gather 和 reduce-scatter 操作,这可能效率低下,因此如果工作负载允许,用户应该仍然对这些模块应用混合精度。注意
默认情况下,如果用户传递的模型包含任何
_BatchNorm
模块并指定了auto_wrap_policy
,则批归一化模块将单独应用 FSDP,并禁用混合精度。请参阅_module_classes_to_ignore
参数。注意
MixedPrecision
默认情况下将cast_root_forward_inputs=True
和cast_forward_inputs=False
。对于根 FSDP 实例,其cast_root_forward_inputs
优先于其cast_forward_inputs
。对于非根 FSDP 实例,其cast_root_forward_inputs
值将被忽略。默认设置对于典型情况是足够的,即每个 FSDP 实例具有相同的MixedPrecision
配置,并且只需要在模型前向传播开始时将输入转换为param_dtype
。注意
对于具有不同
MixedPrecision
配置的嵌套 FSDP 实例,我们建议设置单独的cast_forward_inputs
值,以配置每个实例前向传播之前是否进行类型转换。在这种情况下,由于类型转换发生在每个 FSDP 实例前向传播之前,因此父 FSDP 实例应该在其 FSDP 子模块之前运行其非 FSDP 子模块,以避免激活数据类型因不同的MixedPrecision
配置而改变。示例
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
以上显示了一个可行的示例。另一方面,如果将
model[1]
替换为model[0]
,这意味着使用不同MixedPrecision
的子模块首先执行其前向传播,那么model[1]
将错误地看到float16
激活而不是bfloat16
激活。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]#
此配置用于 CPU 卸载。
- 变量
offload_params (bool) – 此参数指定是否在不涉及计算时将参数卸载到 CPU。如果为
True
,则也会将梯度卸载到 CPU,这意味着优化器步骤在 CPU 上运行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]#
StateDictConfig
是所有state_dict
配置类的基类。用户应实例化一个子类(例如FullStateDictConfig
),以便为 FSDP 支持的相应state_dict
类型配置设置。- 变量
offload_to_cpu (bool) – 如果为
True
,则 FSDP 将状态字典值卸载到 CPU;如果为False
,则 FSDP 将它们保留在 GPU 上。(默认值:False
)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]#
FullStateDictConfig
是一个配置类,旨在与StateDictType.FULL_STATE_DICT
结合使用。我们建议在保存完整状态字典时同时启用offload_to_cpu=True
和rank0_only=True
,分别以节省 GPU 内存和 CPU 内存。此配置类旨在通过state_dict_type()
上下文管理器使用,如下所示>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP( ... model, ... device_id=torch.cuda.current_device(), ... auto_wrap_policy=..., ... sync_module_states=True, ... ) >>> # After this point, all ranks have FSDP model with loaded checkpoint.
- 变量
rank0_only (bool) – 如果为
True
,则只有 rank 0 保存完整的状态字典,非零 rank 保存空字典。如果为False
,则所有 rank 都保存完整的状态字典。(默认值:False
)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source]#
ShardedStateDictConfig
是一个配置类,旨在与StateDictType.SHARDED_STATE_DICT
结合使用。- 变量
_use_dtensor (bool) – 如果为
True
,则 FSDP 将状态字典值保存为DTensor
;如果为False
,则 FSDP 将它们保存为ShardedTensor
。(默认值:False
)
警告
_use_dtensor
是ShardedStateDictConfig
的私有字段,FSDP 使用它来确定状态字典值的类型。用户不应手动修改_use_dtensor
。
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source]#
OptimStateDictConfig
是所有optim_state_dict
配置类的基类。用户应实例化一个子类(例如FullOptimStateDictConfig
),以便为 FSDP 支持的相应optim_state_dict
类型配置设置。- 变量
offload_to_cpu (bool) – 如果为
True
,则 FSDP 将状态字典的张量值卸载到 CPU;如果为False
,则 FSDP 将它们保留在原始设备上(除非启用了参数 CPU 卸载,否则为 GPU)。(默认值:True
)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source]#
- 变量
rank0_only (bool) – 如果为
True
,则只有 rank 0 保存完整的状态字典,非零 rank 保存空字典。如果为False
,则所有 rank 都保存完整的状态字典。(默认值:False
)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source]#
ShardedOptimStateDictConfig
是一个配置类,旨在与StateDictType.SHARDED_STATE_DICT
结合使用。- 变量
_use_dtensor (bool) – 如果为
True
,则 FSDP 将状态字典值保存为DTensor
;如果为False
,则 FSDP 将它们保存为ShardedTensor
。(默认值:False
)
警告
_use_dtensor
是ShardedOptimStateDictConfig
的私有字段,FSDP 使用它来确定状态字典值的类型。用户不应手动修改_use_dtensor
。
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]#