分布式检查点¶
PyTorch/XLA SPMD 通过专用的 Planner
实例与 torch.distributed.checkpoint 库兼容。用户可以通过这个通用接口同步保存和加载检查点。
SPMDSavePlanner 和 SPMDLoadPlanner(源代码)类使 save
和 load
函数可以直接在 XLAShardedTensor
的分片上操作,从而在 SPMD 训练中实现分布式检查点的所有优势。
以下是同步分布式检查点 API 的演示
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc
# Saving a state_dict
state_dict = {
"model": model.state_dict(),
"optim": optim.state_dict(),
}
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
)
...
# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
"model": model.state_dict(),
}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
CheckpointManager¶
实验性的 CheckpointManager 接口在 torch.distributed.checkpoint
函数之上提供了一个更高级别的 API,以实现几个关键功能:
受管检查点:由
CheckpointManager
捕获的每个检查点都由其捕获的步骤标识。所有跟踪的步骤都可以通过CheckpointManager.all_steps
方法访问,任何跟踪的步骤都可以使用CheckpointManager.restore
进行恢复。异步检查点:通过
CheckpointManager.save_async
API 捕获的检查点会异步写入持久存储,以便在检查点期间不阻塞训练。输入的分片状态字典首先被移到 CPU,然后检查点被分派到一个后台线程。抢占时的自动检查点:在 Cloud TPU 上,可以检测到抢占,并在进程终止前捕获检查点。要使用,请确保您的 TPU 是通过启用了 自动检查点 的 QueuedResource 预配的,并确保在构造 CheckpointManager 时设置了
chkpt_on_preemption
参数(此选项默认启用)。FSSpec 支持:
CheckpointManager
使用 fsspec 存储后端,允许直接向任何 fsspec 兼容的文件系统(包括 GCS)进行检查点。
CheckpointManager 的示例用法如下
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
# Choose the highest step
best_step = max(tracked_steps)
# Before restoring the checkpoint, the optimizer state must be primed
# to allow state to be loaded into it.
prime_optimizer(optim)
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
chkpt_mgr.restore(best_step, state_dict)
model.load_state_dict(state_dict['model'])
optim.load_state_dict(state_dict['optim'])
# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
...
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
if chkpt_mgr.save_async(step, state_dict):
print(f'Checkpoint taken at step {step}')
恢复优化器状态¶
在分布式检查点中,状态字典是就地加载的,并且只加载检查点所需的片段。由于优化器状态是惰性创建的,因此状态在第一次调用 optimizer.step
之前不存在,并且尝试加载未初始化的优化器将失败。
为此提供了实用方法 prime_optimizer
:它通过将所有梯度设置为零并调用 optimizer.step
来运行一个假的训练步骤。这是一个破坏性方法,会触及模型参数和优化器状态,因此应仅在恢复之前调用。
进程组¶
要使用 torch.distributed
API(如分布式检查点),需要一个进程组。在 SPMD 模式下,由于编译器负责所有集合通信,因此不支持 xla
后端。
相反,必须使用 CPU 进程组,例如 gloo
。在 TPUs 上,xla://
init_method 仍然支持用于发现主 IP、全局世界大小和主机等级。初始化示例如下:
import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr
xr.use_spmd()
# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')