分布式 Checkpoint (DCP) 入门#
创建于:2023年10月02日 | 最后更新:2025年07月10日 | 最后验证:2024年11月05日
作者:Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin
注意
在 GitHub 上查看和编辑此教程。
先决条件
在分布式训练期间 checkpoint AI 模型可能具有挑战性,因为参数和梯度分布在多个训练器中,并且在恢复训练时可用训练器的数量可能会发生变化。PyTorch 分布式 Checkpoint (DCP) 可以帮助简化此过程。
在本教程中,我们将展示如何使用 DCP API 和一个简单的 FSDP 包装模型。
DCP 的工作原理#
torch.distributed.checkpoint()
支持从多个 rank 并行保存和加载模型。您可以使用此模块并行地在任意数量的 rank 上进行保存,然后在加载时跨不同的集群拓扑重新分片。
此外,通过使用 torch.distributed.checkpoint.state_dict()
中的模块,DCP 支持在分布式环境中优雅地处理 state_dict
的生成和加载。这包括管理模型和优化器之间的全限定名 (FQN) 映射,以及为 PyTorch 提供的并行性设置默认参数。
DCP 与 torch.save()
和 torch.load()
在几个重要方面有所不同
它为每个 checkpoint 生成多个文件,每个 rank 至少一个。
它原地操作,意味着模型应首先分配其数据,DCP 使用该存储空间。
DCP 提供了对 Stateful 对象的特殊处理(在 torch.distributed.checkpoint.stateful 中正式定义),如果定义了 state_dict 和 load_state_dict 方法,它会自动调用它们。
注意
本教程中的代码运行在一个 8 GPU 服务器上,但可以轻松地推广到其他环境。
如何使用 DCP#
在这里,我们使用一个包装了 FSDP 的玩具模型进行演示。同样,这些 API 和逻辑也可应用于大型模型的 checkpoint。
保存#
现在,让我们创建一个玩具模块,用 FSDP 包装它,用一些模拟输入数据喂给它,然后保存它。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355 "
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_save_example(rank, world_size):
print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = fully_shard(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
optimizer.zero_grad()
model(torch.rand(8, 16, device="cuda")).sum().backward()
optimizer.step()
state_dict = { "app": AppState(model, optimizer) }
dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_save_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
请继续检查 checkpoint 目录。您应该看到与下面所示文件数量对应的 checkpoint 文件。例如,如果您有 8 个设备,您应该看到 8 个文件。

加载#
保存后,让我们创建相同的 FSDP 包装模型,并将保存的状态字典从存储加载到模型中。您可以以相同的 world size 或不同的 world size 加载。
请注意,您必须在加载之前调用 model.state_dict()
并将其传递给 DCP 的 load_state_dict()
API。这与 torch.load()
fundamentally 不同,因为 torch.load()
仅需要 checkpoint 的路径来进行加载。我们需要 state_dict
才能加载的原因是:
DCP 使用模型 state_dict 中预先分配的存储来从 checkpoint 目录加载。加载过程中,传入的 state_dict 将被原地更新。
DCP 在加载之前需要模型的 sharding 信息以支持重新分片。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355 "
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_load_example(rank, world_size):
print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = fully_shard(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
state_dict = { "app": AppState(model, optimizer)}
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_load_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
如果您想将保存的 checkpoint 加载到一个非 FSDP 包装的模型中,在非分布式设置下,或许是为了推理,您也可以使用 DCP 来实现。默认情况下,DCP 以单程序多数据 (SPMD) 样式保存和加载分布式 state_dict
。但是,如果未初始化进程组,DCP 会推断意图是以“非分布式”样式保存或加载,即完全在当前进程中。
注意
多程序多数据 (MPMD) 的分布式 checkpoint 支持仍在开发中。
import os
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
CHECKPOINT_DIR = "checkpoint"
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = ToyModel()
state_dict = {
"model": model.state_dict(),
}
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
model.load_state_dict(state_dict["model"])
if __name__ == "__main__":
print(f"Running basic DCP checkpoint loading example.")
run_checkpoint_load_example()
格式#
尚未提到的一个缺点是 DCP 保存的 checkpoint 格式与使用 torch.save 生成的格式本质上是不同的。当用户希望与其他习惯于 torch.save 格式的用户共享模型,或者只是想为他们的应用程序添加格式灵活性时,这可能会成为一个问题。在这种情况下,我们在 torch.distributed.checkpoint.format_utils
中提供了 format_utils
模块。
为了方便用户,提供了一个命令行工具,其格式如下:
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
在上面的命令中,mode
是 torch_to_dcp
或 dcp_to_torch
之一。
或者,也为可能希望直接转换 checkpoint 的用户提供了方法。
import os
import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"
# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)
# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
结论#
总而言之,我们学习了如何使用 DCP 的 save()
和 load()
API,以及它们与 torch.save()
和 torch.load()
的区别。此外,我们还学习了如何使用 get_state_dict()
和 set_state_dict()
来自动管理状态字典生成和加载过程中的特定于并行性的 FQN 和默认值。
更多信息,请参阅以下内容