分布式检查点 - torch.distributed.checkpoint#
创建于: 2022年11月16日 | 最后更新于: 2025年9月4日
分布式检查点 (DCP) 支持并行地从多个 rank 加载和保存模型。它处理加载时的重分片,这使得可以在一种集群拓扑中保存,而在另一种集群拓扑中加载。
DCP 在几个重要方面与 torch.save
和 torch.load
不同
它为每个检查点生成多个文件,每个 rank 至少一个。
它会就地操作,这意味着模型应先分配其数据,DCP 使用该存储空间。
加载和保存检查点的入口点如下:
其他资源:#
- class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)[source]#
异步检查点类型的枚举。
- class torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse(staging_completion, upload_completion)[source]#
此类包含暂存和上传完成的 Future。它由 async_save() 返回。staging_completion 是一个表示 state_dict 本地副本完成的 Future。upload_completion 是一个表示检查点保存完成的 Future。
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False, use_collectives=True)[source]#
以 SPMD 风格保存分布式模型。
此函数与
torch.save()
不同,因为它会处理ShardedTensor
和DTensor
,让每个 rank 只保存其本地分片。对于每个
Stateful
对象(同时具有state_dict
和load_state_dict
),save 会在序列化之前调用state_dict
。警告
PyTorch 版本之间保存的 state_dict 不保证向后兼容。
警告
如果使用 process_group 参数,请确保只有其 rank 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。
注意
当为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存检查点时,只有其中一个 shard_group 应该调用 save_state_dict,并且需要传入相应的进程组。
注意
- 如果没有可用的进程组,此函数将假定意图是保存
本地进程中的 state_dict。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认:
None
)storage_writer (Optional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定此项,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则将引发异常。(默认:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定此项,将使用默认规划器。(默认:
None
)process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认:
None
)no_dist (bool) – 如果为
True
,此函数将假定意图是在单个 rank/进程上加载检查点。(默认:False
)use_collectives (bool) – 如果为
False
,此函数将假定意图是保存检查点而不使用跨 rank 同步。(默认:True
) 此配置是实验性的,应谨慎使用。它将更改保存的检查点格式,并且可能不向后兼容。
- 返回
保存的检查点的元数据对象。
- 返回类型
元数据
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( ... "/checkpoint/1" ... ) >>> torch.distributed.checkpoint.save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用集体通信来协调 rank 之间的写入。对于基于 NCCL 的进程组,对象内部的张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由
torch.cuda.current_device()
指定,用户有责任确保通过torch.cuda.set_device()
设置每个 rank 拥有独立的 GPU。
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD, async_stager=None, no_dist=False, use_collectives=True)[source]#
异步版本的
save
。此代码首先将 state_dict 暂存到暂存存储(默认为 CPU 内存),然后在一个单独的线程中调用 save。警告
此功能是实验性的,可能会发生更改。必须在最后一个检查点保存后调用 CLOSE。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认:
None
)storage_writer (Optional[StorageWriter]) – 用于执行“暂存”和“保存”的 StorageWriter 实例。如果未指定此项,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则将引发异常。(默认:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定此项,将使用默认规划器。(默认:
None
)process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认:
None
)async_checkpointer_type (AsyncCheckpointerType) – 是在单独的线程还是进程中执行检查点(默认:
AsyncCheckpointerType.THREAD
)async_stager (AsyncStager) – 提供暂存实现。如果 storage_writer 实现 AsyncStager 并且提供了 async_stager,则将使用 async_stager 进行暂存。
no_dist (bool) – 如果为
True
,此函数将假定意图是保存单个 rank/进程的检查点。(默认:False
)use_collectives (bool) – 如果为 False,则不带 rank 协调地保存检查点。(默认:
True
) 此配置是实验性的,应谨慎使用。它将更改保存的检查点格式,并且可能不向后兼容。
- 返回
包含 save 返回的元数据对象的 Future。
- 返回类型
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( ... "/checkpoint/1" ... ) >>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> ) >>> >>> # ... do some work ... >>> >>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]#
此方法已弃用。请切换到“save”。
- 返回类型
元数据
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)[source]#
以 SPMD 风格将检查点加载到分布式 state_dict 中。
每个 rank 提供的
state_dict
必须具有相同的键。键不匹配可能导致挂起或错误。如果不确定,可以使用utils._assert_same_keys
API 进行检查(但可能会产生通信开销)。每个 rank 将尝试读取满足所请求 state_dict 所需的最少数据量。加载
ShardedTensor
或DTensor
实例时,每个 rank 只读取其本地分片的数据。对于每个
Stateful
对象(同时具有state_dict
和load_state_dict
),load 将在反序列化之前调用state_dict
,然后在反序列化完成后调用load_state_dict
。对于每个非Stateful
对象,load 将反序列化该对象,然后用反序列化的对象替换state_dict
中的对象。警告
state_dict
中的所有张量必须在调用此函数*之前*已分配到其目标设备上。所有非张量数据使用 torch.load() 加载,并在 state_dict 上就地修改。
警告
用户必须在根模块上调用 load_state_dict,以确保加载后的处理和非张量数据能够正确传播。
- 参数
state_dict (Dict[str, Any]) – 要将检查点加载到的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认:
None
)storage_reader (Optional[StorageReader]) – 用于执行读取的 StorageWriter 实例。如果未指定此项,DCP 将根据 checkpoint_id 自动推断读取器。如果 checkpoint_id 也为 None,则将引发异常。(默认:
None
)planner (Optional[LoadPlanner]) – LoadPlanner 实例。如果未指定此项,将使用默认规划器。(默认:
None
)process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认:
None
)no_dist (bool) – 如果为
True
,此函数将假定意图是加载检查点而不使用跨 rank 同步。(默认:False
)
- 返回
无。
- 返回类型
无
- 示例
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( ... "/checkpoint/1" ... )
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用集体通信来协调 rank 之间的读取。对于基于 NCCL 的进程组,对象内部的张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由
torch.cuda.current_device()
指定,用户有责任确保通过torch.cuda.set_device()
设置每个 rank 拥有独立的 GPU。
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]#
此方法已弃用。请切换到“load”。
以下模块对于异步检查点(torch.distributed.checkpoint.async_save
)所使用的暂存机制的额外自定义也很有用。
- class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source]#
此协议旨在为 dcp.async_save 提供自定义和可扩展性,允许用户自定义在执行常规 dcp.save 路径之前如何暂存数据。预期的操作顺序(在 torch.distributed.state_dict_saver.async_save 中具体定义)如下:
- AsyncStager.stage_data(state_dict)
此调用让 AsyncStager 有机会“暂存”state_dict。在此上下文中暂存的期望和目的是创建 state dict 的“训练安全”表示,这意味着在暂存完成后对模块数据进行的任何更新都不应反映在此方法返回的 state dict 中。例如,在默认情况下,会创建整个 state dict 的副本并将其置于 CPU RAM 中返回,从而允许用户继续训练而不会冒着修改正在序列化的数据的风险。
- dcp.save 会在暂存后返回的 state_dict 上并行调用。此调用负责
序列化 state_dict 并将其写入存储。
- 如果 AsyncStager.should_synchronize_after_execute 为 True,则将在
序列化线程启动后立即调用此方法,并在从 dcp.async_save 返回之前调用。如果设置为 False,则假设用户已定义自定义同步点以进一步优化训练循环中的保存延迟(例如,通过将暂存与前向/后向传递重叠),并且用户有责任在适当的时间调用 AsyncStager.synchronize_staging。
- class torch.distributed.checkpoint.staging.DefaultStager(config=StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True))[source]#
DefaultStager 提供了一个功能齐全的暂存实现,它结合了多种优化技术,用于高效的检查点准备。
暂存过程如下:1. 提交 state dictionary 进行暂存(同步或异步)2. 将张量从 GPU 复制到优化的 CPU 存储 3. 如果使用非阻塞复制,则同步 CUDA 操作 4. 返回暂存的 state dictionary 或通过 Future 提供。
- 使用模式
# 同步暂存 stager = DefaultStager(StagingOptions(use_async_staging=False)) staged_dict = stager.stage(state_dict) stager.close()
# 异步暂存 stager = DefaultStager(StagingOptions(use_async_staging=True)) future = stager.stage(state_dict) # ... 执行其他工作 ... staged_dict = future.result() stager.close()
# 上下文管理器模式(推荐) stager = DefaultStager(config) with stager: result = stager.stage(state_dict)
- 性能考虑
当模型计算可以与暂存操作重叠时,异步暂存提供最佳性能。
固定内存可提高 CPU-GPU 传输速度,但会消耗更多内存。
共享内存允许与检查点进程进行高效的 IPC。
非阻塞复制减少了 GPU 在内存传输期间的空闲时间。
- 线程安全
DefaultStager 不是线程安全的。每个线程应使用自己的实例,或提供外部同步。
- close()[source]#
清理 DefaultStager 使用的所有资源。关闭用于异步暂存操作的 ThreadPoolExecutor,并清理底层的 StateDictStager 的缓存存储。应在不再需要暂存器时调用它,以防止资源泄漏,尤其是在长时间运行的应用程序中。调用 close() 后,不应再使用暂存器进行进一步的暂存操作。
- 示例用法
stager = DefaultStager(StagingOptions(use_async_staging=True)) future = stager.stage(state_dict) result = future.result() stager.close() # 清理所有资源
- class torch.distributed.checkpoint.staging.StagingOptions(use_pinned_memory=True, use_shared_memory=True, use_async_staging=True, use_non_blocking_copy=True)[source]#
检查点暂存行为的配置选项。
- 变量
注意
如果 CUDA 不可用,依赖 CUDA 的功能将引发异常。
- class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source]#
AsyncStager 的一个实现,它将 state_dict 暂存到 CPU RAM 并阻塞直到复制完成。此实现还提供了一个选项来使用固定内存优化暂存延迟。
注意:在这种情况下,synchronize_staging 是一个 no-op。
除了上述入口点外,Stateful
对象(如下所述)在保存/加载期间提供了额外的自定义。
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source]#
可检查点和恢复的对象的 Stateful 协议。
此 示例 显示了如何使用 Pytorch Distributed Checkpoint 保存 FSDP 模型。
以下类型定义了检查点期间使用的 IO 接口。
- class torch.distributed.checkpoint.StorageReader[source]#
load_state_dict
用于从存储读取的接口。一个 StorageReader 实例在分布式检查点中同时充当协调器和跟随者。在初始化时,每个实例都被告知其角色。
子类应期望
load_state_dict
调用以下序列:(所有 rank)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。
(所有 rank)read_metadata()
(所有 rank)set_up_storage_reader()
(所有 rank)prepare_local_plan()
(协调器)prepare_global_plan()
(所有 rank)read_data()
- abstract prepare_global_plan(plans)[source]#
执行存储加载的集中式规划。
此方法仅在协调器实例上调用。
虽然此方法可以产生一个完全不同的计划,但首选方式是将存储特定数据存储在 LoadPlan::storage_data 中。
- 参数
plans (list[torch.distributed.checkpoint.planner.LoadPlan]) –
LoadPlan
实例的列表,每个 rank 一个。- 返回
存储全局规划后的转换后的
LoadPlan
列表。- 返回类型
- abstract prepare_local_plan(plan)[source]#
执行存储特定的本地规划。
虽然此方法可以产生一个完全不同的计划,但推荐的方式是将存储特定数据存储在 LoadPlan::storage_data 中。
- abstract read_data(plan, planner)[source]#
使用
planner
读取plan
中的所有项以解析数据。子类应调用
LoadPlanner::load_bytes
将 BytesIO 对象反序列化到正确的位置。子类应调用
LoadPlanner::resolve_tensor
来访问它们应该加载数据的张量。存储层负责正确调度任何所需的跨设备复制。
- 参数
plan (LoadPlan) – 要执行的本地计划。
planner (LoadPlanner) – 用于解析项的规划器对象。
- 返回
在所有读取完成后完成的 Future。
- 返回类型
Future[None]
- abstract reset(checkpoint_id=None)[source]#
调用以指示将要发生全新的检查点读取。如果用户为此检查点读取设置了 checkpoint_id,则可能存在 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是文件夹/文件的路径,也可以是键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储更像键值存储,它也可以是键。(默认:
None
)
- class torch.distributed.checkpoint.StorageWriter[source]#
save_state_dict
用于写入存储的接口。一个 StorageWriter 实例在分布式检查点中同时充当协调器和跟随者。在初始化时,每个实例都被告知其角色。
子类应期望以下调用顺序。
(所有 rank)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。
(所有 rank)set_up_storage_writer()
(所有 rank)prepare_local_plan()
(协调器)prepare_global_plan()
(所有 rank)write_data()
(协调器)finish()
- abstract finish(metadata, results)[source]#
写入元数据并标记当前检查点为成功。
用于序列化 metadata 的实际格式/模式是实现细节。唯一的要求是它可以恢复到相同的对象图。
- abstract prepare_global_plan(plans)[source]#
执行存储的集中式规划。
此方法仅在协调器实例上调用。
虽然此方法可以产生一个完全不同的计划,但首选方式是将存储特定数据存储在 SavePlan::storage_data 中。
- 参数
plans (list[torch.distributed.checkpoint.planner.SavePlan]) –
SavePlan
实例的列表,每个 rank 一个。- 返回
存储全局规划后的转换后的
SavePlan
列表。- 返回类型
- abstract prepare_local_plan(plan)[source]#
执行存储特定的本地规划。
虽然此方法可以产生一个完全不同的计划,但推荐的方式是将存储特定数据存储在 SavePlan::storage_data 中。
- abstract reset(checkpoint_id=None)[source]#
调用以指示将要发生全新的检查点写入。如果用户为此检查点写入设置了 checkpoint_id,则可能存在 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是文件夹/文件的路径,也可以是键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认:
None
)
- abstract set_up_storage_writer(is_coordinator, *args, **kwargs)[source]#
初始化此实例。
- 参数
is_coordinator (bool) – 此实例是否负责协调检查点。
- storage_meta()[source]#
返回存储特定的元数据。这用于在检查点中存储额外信息,这些信息对于提供请求级别的可观察性很有用。StorageMeta 在保存调用期间传递给
SavePlanner
。默认返回 None。TODO: 提供一个示例
- 返回类型
Optional[StorageMeta]
- abstract classmethod validate_checkpoint_id(checkpoint_id)[source]#
检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。
- 返回类型
- abstract write_data(plan, planner)[source]#
使用
planner
解析数据,将plan
中的所有项写入存储。子类应在计划中的每个项上调用
SavePlanner::resolve_data
以获取要写入的数据的底层访问权限。子类应延迟调用 resolve_data,因为它可能分配内存。对于张量,请做以下假设:
它们可能位于任何设备上,包括与
WriteItem::tensor_data
上不匹配的设备。它们可能是视图或不连续的。只需要保存投影。
- 参数
plan (SavePlan) – 要执行的保存计划。
planner (SavePlanner) – 用于将项解析为数据的规划器对象。
- 返回
一个 Future,完成后会得到一个 WriteResult 列表。
- 返回类型
Future[list[torch.distributed.checkpoint.storage.WriteResult]]
以下类型定义了检查点期间使用的规划器接口。
- class torch.distributed.checkpoint.LoadPlanner[source]#
定义 load_state_dict 用于规划加载过程的协议的抽象类。
LoadPlanner 是有状态对象,可用于自定义整个加载过程。
LoadPlanner 作为 state_dict 的访问代理,因此对其进行的任何转换都将对整个进程可见。
规划器子类可以在 load_state_dict 期间预期以下调用序列:
- set_up_planner - 在所有 rank 上调用。
指示检查点加载的开始。
- create_local_plan - 在所有 rank 上调用。
处理 state_dict 并生成将用于全局规划的 LoadPlan。
- create_global_plan - 仅在协调器 rank 上调用。
获取所有 rank 的 LoadPlan 并做出任何全局决策。
- load_bytes - 在每个 rank 上多次调用。
这在 state_dict 的每个非张量值调用一次。
- resolve_tensor 和 commit_tensor - 在每个 rank 上多次调用。
它们成对调用,用于 state_dict 中的每个张量值。
建议用户扩展 DefaultLoadPlanner 而不是直接扩展此接口,因为大多数更改都可以通过修改单个方法来表达。
有两种常见的扩展模式:
重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 的复杂性。我们需要保留对原始 state_dict 的引用,因为加载是就地进行的,因此我们需要能够就地执行它。
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> metadata: Metadata, >>> is_coordinator: bool, >>> ) -> None: >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
修改 resolve_tensor 和 commit_tensor 以处理加载时转换。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)[source]#
在 StorageReader 完成将数据加载到
tensor
中后调用一次。提供的张量与
resolve_tensor
的调用返回的张量相同。如果此 LoadPlanner 需要在将张量复制回 state_dict 中的张量之前进行后处理,则需要此方法。张量的内容将遵循其设备同步模型。
- abstract create_local_plan()[source]#
根据 set_up_planner 提供的 state_dict 和元数据创建 LoadPlan。
. 注意:这在每个 rank 上都调用。
- 返回类型
- abstract load_bytes(read_item, value)[source]#
加载
read_item``和 ``value
描述的项目。此方法应就地修改底层 state_dict。
value
的内容由用于生成正在加载的检查点的 SavePlanner 定义。
- resolve_bytes(read_item)[source]#
返回 StorageReader 用于加载 read_item 的 BytesIO。
BytesIO 应与底层 state_dict 上的 BytesIO 别名,因为 StorageReader 将替换其内容。
- 返回类型
BytesIO
- class torch.distributed.checkpoint.LoadPlan(items: list[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]#
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]#
- class torch.distributed.checkpoint.SavePlanner[source]#
定义 save_state_dict 用于规划保存过程的协议的抽象类。
SavePlanner 是有状态对象,可用于自定义整个保存过程。
SavePlanner 作为 state_dict 的访问代理,因此对其进行的任何转换都将对整个进程可见。
规划器子类可以在 save_state_dict 期间预期以下调用序列:
- set_up_planner - 在所有 rank 上调用。
指示检查点保存的开始。
- create_local_plan - 在所有 rank 上调用。
处理 state_dict 并生成将用于全局规划的 SavePlan。
- create_global_plan - 仅在协调器 rank 上调用。
获取所有 rank 的 SavePlan 并做出任何全局决策。
- finish_plan - 在所有 rank 上调用。
这使得每个 rank 都有机会适应全局规划决策。
- resolve_data - 在每个 rank 上多次调用。
查找 state_dict 上的值,供存储层写入。
建议用户扩展 DefaultSavePlanner 而不是直接扩展此接口,因为大多数更改都可以通过修改单个方法来表达。
有 3 种常用的扩展模式:
重写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 的复杂性。
>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> storage_meta: Optional[StorageMeta], >>> is_coordinator: bool, >>> ) -> None: >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
修改本地计划和查找以协同工作。这对于如何持久化数据进行精细控制很有用。
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
使用全局规划步骤来做出每个 rank 无法单独做出的中央决策。
>>> from itertools import zip_longest >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> iters = [iter(all_plans[0].items)] * len(all_plans) >>> items_per_rank = [ >>> [item for item in items if item is not None] >>> for items in zip(*zip_longest(*iters), strict=True) >>> ] >>> all_plans = [ >>> replace(plan, items=items) >>> for plan, items in zip(all_plans, items_per_rank, strict=True) >>> ] >>> return super().create_global_plan(all_plans)
最后,一些规划器需要将其他元数据保存在检查点中,这通过让每个 rank 在本地计划中贡献其数据项,然后由全局规划器进行聚合来实现。
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- abstract create_global_plan(all_plans)[source]#
计算全局检查点计划并返回每个 rank 的本地计划。
这仅在协调器 rank 上调用。
- 返回类型
tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]
- abstract create_local_plan()[source]#
计算当前 rank 的保存计划。
这将被聚合并传递给 `create_global_plan`。计划特定的数据可以通过 `SavePlan::planner_data` 传递。
此方法会在所有 rank 上调用。
- 返回类型
- abstract finish_plan(new_plan)[source]#
合并由 `create_local_plan` 创建的计划和 `create_global_plan` 的结果。
此方法会在所有 rank 上调用。
- 返回类型
- abstract resolve_data(write_item)[source]#
转换并准备 `state_dict` 中用于存储的 `write_item`,确保幂等性和线程安全。
在 `state_dict` 中查找与 `write_item` 关联的对象,并在存储层使用它之前应用任何转换(例如序列化)。
此方法会在每个 rank 上多次调用,对于最终 `SavePlan` 中的每个 `WriteItem` 至少调用一次。
此方法应该是幂等且线程安全的。`StorageWriter` 实现可以根据需要频繁调用它。
任何会分配内存的转换都应在该方法被调用时惰性执行,以减少 checkpointing 所需的峰值内存。
返回张量时,它们可以位于任何设备或格式,也可以是视图。存储层负责确定如何保存它们。
- class torch.distributed.checkpoint.SavePlan(items: list[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None, usable: bool = True)[source]#
- class torch.distributed.checkpoint.planner.WriteItem(index, type, bytes_io_data=None, tensor_data=None)[source]#
保存有关需要写入存储的数据信息的类。
我们提供一个基于文件系统的存储层。
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None, serialization_format=SerializationFormat.TORCH_SAVE)[source]#
使用文件 IO 的 `StorageWriter` 的基本实现。
此实现进行了以下假设和简化:
`checkpoint` 路径是一个空目录或不存在的目录。
文件创建是原子的。
checkpoint 由每个写入请求一个文件组成,外加一个全局 `.metadata` 文件(如果启用了 rank 协调)。如果未启用 rank 协调,则会有一个 rank 本地的 `__rank.metadata` 文件。
我们也提供其他存储层,包括与 HuggingFace safetensors 交互的存储层。
.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageReader :members
.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter :members
.. autoclass:: torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader :members
我们提供了 `LoadPlanner` 和 `SavePlanner` 的默认实现,它们可以处理所有 torch.distributed 的构造,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)[source]#
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]#
为 `LoadPlanner` 提供了额外的功能。
特别是,它增加了以下功能:
`flatten_state_dict`: 处理嵌套字典的 `state_dict`。 `flatten_sharded_tensors`: 对于 2D 并行模式下的 FSDP,`allow_partial_load`: 如果为 False,则当 `state_dict` 中存在某个键但在 checkpoint 中不存在时,会引发运行时错误。
由于历史设计决策,即使原始未并行化的模型相同,`FSDP` 和 `DDP` 的状态字典也可能具有不同的键或完全限定名(例如,`layer1.weight`)。此外,`FSDP` 提供了各种模型状态字典类型,如完整和分片状态字典。此外,优化器状态字典使用参数 ID 而不是完全限定名来标识参数,这在使用并行性(例如,流水线并行)时可能会导致问题。
为了解决这些挑战,我们提供了一系列 API,用户可以轻松管理 `state_dict`。`get_model_state_dict()` 返回一个模型状态字典,其键与未并行化的模型状态字典返回的键一致。类似地,`get_optimizer_state_dict()` 提供了一个优化器状态字典,其键在所有应用的并行性之间保持一致。为了实现这种一致性,`get_optimizer_state_dict()` 将参数 ID 转换为与未并行化的模型状态字典中的参数相同的完全限定名。
请注意,这些 API 返回的结果可以直接与 `torch.distributed.checkpoint.save()` 和 `torch.distributed.checkpoint.load()` 方法一起使用,无需任何额外的转换。
`set_model_state_dict()` 和 `set_optimizer_state_dict()` 用于加载由其各自的 getter API 生成的模型和优化器 `state_dict`。
请注意,`set_optimizer_state_dict()` 只能在调用优化器上的 `backward()` 之前或 `step()` 之后调用。
请注意,此功能是实验性的,API 签名在将来可能会发生变化。
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]#
返回模型 `state_dict` 和优化器 `state_dict`。
`get_state_dict` 可以处理由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 并行化的任何模块,以及这些并行性的任何组合。`get_state_dict` 的主要功能是:1)返回一个模型和优化器 `state_dict`,该 `state_dict` 可以用不同数量的训练器和/或不同的并行性进行重分片。2)隐藏特定于并行性的 `state_dict` API。用户无需调用这些 API。3)对结果 `state_dict` 进行健全性检查。
结果 `state_dict` 的键是规范的 FQN(完全限定名)。规范 FQN 指的是基于参数在 `nn.Module` 层次结构中的位置的 FQN。更具体地说,到参数的规范 FQN 是当模块未被任何并行性分布式时,由 `module.named_parameters()` 或 `module.named_buffers()` 返回的 FQN。由于优化器内部使用参数 ID 来表示参数,因此在调用此 API 时会进行从参数 ID 到规范 FQN 的转换。
`get_state_dict` 也可以处理未并行化的模块。在这种情况下,`get_state_dict` 只执行一项功能——将优化器参数 ID 转换为规范 FQN。
示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( ... fsdp_model, fsdp_optim ... )
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 `model` 的优化器。
submodules (deprecated) – Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。
options (StateDictOptions) – 控制如何返回模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
包含模型 `state_dict` 和优化器 `state_dict` 的 `Tuple`。
- 返回类型
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]#
返回 `model` 的模型 `state_dict`。
有关详细用法,请参阅 `get_state_dict`。
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
submodules (deprecated) – Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。
options (StateDictOptions) – 控制如何返回模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
`model` 的 `state_dict`。
- 返回类型
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]#
返回合并后的优化器 `state_dict`。
有关详细用法,请参阅 `get_state_dict`。
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 `model` 的优化器。
submodules (deprecated) – Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。
options (StateDictOptions) – 控制如何返回模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
`optimizers` 的 `state_dict`。
- 返回类型
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]#
加载模型 `state_dict` 和优化器 `state_dict`。
与 `get_state_dict` 相对应的函数,用于将 `state_dict` 设置到模型和优化器。给定的 `model_state_dict` 和 `optim_state_dict` 不必由 `get_state_dict` 返回,但必须满足以下要求:1)所有 FQN 都是 `get_state_dict` 中定义的规范 FQN,2)如果张量被分片,它必须是 ShardedTensor 或 DTensor,3)优化器 `state_dict` 不能包含参数 ID;键应该是规范 FQN。
- 警告:`set_state_dict` 只能在调用优化器上的 `backward()` 之前或 `step()` 之后调用。
否则,优化器状态将无法正确初始化。
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 `model` 的优化器。
model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型 `state_dict`。如果 `model_state_dict` 的键是 `nn.Module`,则键是 `model` 的一个子模块,值应该是该子模块的 `state_dict`。加载 `state_dict` 时,子模块的路径会附加到 `state_dict`。
optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 `state_dict`。
options (StateDictOptions) – 控制如何加载模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
missing_keys 是一个包含模型 `state_dict` 缺失键的字符串列表。
unexpected_keys 是一个包含模型 `state_dict` 意外键的字符串列表。
- 返回类型
NamedTuple
,包含missing_keys
和unexpected_keys
字段。
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]#
加载模型 `state_dict`。
`get_model_state_dict` 的对应函数,用于将 `state_dict` 设置到模型。有关详细用法,请参阅 `set_state_dict`。
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要加载的模型 `state_dict`。如果 `model_state_dict` 的键是 `nn.Module`,则键是 `model` 的一个子模块,值应该是该子模块的 `state_dict`。加载 `state_dict` 时,子模块的路径会附加到 `state_dict`。
options (StateDictOptions) – 控制如何加载模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
missing_keys 是一个包含缺失键的字符串列表。
unexpected_keys 是一个包含意外键的字符串列表。
- 返回类型
NamedTuple
,包含missing_keys
和unexpected_keys
字段。
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source]#
加载优化器 `state_dict`。
`get_optimizer_state_dict` 的对应函数,用于将 `state_dict` 设置到优化器。有关详细用法,请参阅 `set_state_dict`。
- 警告:`set_optimizer_state_dict` 只能在调用优化器上的 `backward()` 之前或
`step()` 之后调用。否则,优化器状态将无法正确初始化。
- 参数
model (nn.Module) – 要处理的模型 `nn.Module`。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 `model` 的优化器。
optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 `state_dict`。
options (StateDictOptions) – 控制如何加载模型 `state_dict` 和优化器 `state_dict` 的选项。有关详细信息,请参阅 `StateDictOptions`。
- 返回
无
- 返回类型
无
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')[source]#
此数据类指定 `get_state_dict`/`set_state_dict` 如何工作。
`full_state_dict`: 如果设置为 True,则将收集返回的 `state_dict` 中的所有张量。返回的 `state_dict` 中将不包含 ShardedTensor 和 DTensor。
`cpu_offload`: 将所有张量卸载到 CPU。为了防止 CPU OOM,如果 `full_state_dict` 也为 True,则只有 rank0 会收到 `state_dict`,其他所有 rank 都会收到空的 `state_dict`。
`ignore_frozen_params`: 如果值为 True,则返回的 `state_dict` 将不包含任何冻结的参数(`requires_grad` 为 False)。默认值为 False。
`keep_submodule_prefixes` (已弃用): 当 `submodules` 不为 None 时,此选项指示是否保留 `state_dict` 键中的子模块前缀。例如,如果子模块是 `module.pretrain`,并且参数的完整 FQN 是 `pretrain.layer1.weight`。当此选项为 True 时,返回的 `state_dict` 中的参数键将是 `pretrain.layer1.weight`。如果选项为 False,键将是 `layer1.weight`。请注意,如果 `keep_submodule_prefixes` 为 False,可能会出现冲突的 FQN,因此 `submodules` 中应该只有一个子模块。
`strict`: `set_state_dict` 调用 `model.load_state_dict()` 时的 `strict` 选项。
- `broadcast_from_rank0`: 当选项为 True 时,rank0 应该接收一个
完整的 `state_dict`,并将 `state_dict`/`optim_state_dict` 中的张量逐个广播到其他 rank。其他 rank 将接收张量并根据模型和优化器中的本地分片进行分片。使用此选项时,必须将 `full_state_dict` 设置为 True。此选项当前仅支持 DTensor,不支持旧版 ShardedTensor。
对于习惯使用 `torch.save` 格式并共享模型的用户,提供了以下方法,它们提供离线工具来在格式之间进行转换。
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source]#
给定一个包含 DCP checkpoint 的目录,此函数将其转换为 Torch save 文件。
- 参数
警告
为避免 OOM,建议仅在一个 rank 上运行此函数。
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source]#
给定一个 Torch save 文件的位置,将其转换为 DCP checkpoint。
- 参数
警告
为避免 OOM,建议仅在一个 rank 上运行此函数。
以下类也可以用于从 torch.save 格式在线加载和重分片模型。
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source]#
`StorageReader` 用于读取 Torch Save 文件。此 reader 将在协调器 rank 上读取整个 checkpoint,然后将每个张量广播并分片到所有 rank。
. 注意: intended to be used with DynamicMetaLoadPlanner
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]#
`DefaultLoadPlanner` 的扩展,它根据传入的 `state_dict` 创建一个新的 Metadata 对象,避免了从磁盘读取元数据的需要。这对于读取没有元数据文件的格式(如 Torch Save 文件)非常有用。
. 注意: intended to be used with BroadcastingTorchSaveReader
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
为提高生产环境的可观测性,提供以下实验性接口。