使用分布式检查点 (DCP) 进行异步保存
创建日期:2024年7月22日 | 最后更新:2026年2月3日 | 最后验证:2024年11月5日
作者: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Yunsheng Ni
检查点保存通常是分布式训练工作负载中的关键路径瓶颈,随着模型和集群规模的扩大,其代价也越来越高。抵消这一开销的一个极佳策略是进行并行、异步的检查点保存。下面,我们将扩展分布式检查点入门教程中的保存示例,展示如何通过 torch.distributed.checkpoint.async_save 轻松实现这一功能。
如何使用 DCP 并行生成检查点
优化性能的有效策略
PyTorch v2.4.0 或更高版本
异步检查点概述
在开始使用异步检查点之前,了解其与同步检查点的差异和局限性非常重要。具体如下:
- 内存需求 - 异步检查点的工作原理是首先将模型复制到内部 CPU 缓冲区中。这很有帮助,因为它能确保在检查点保存过程中模型和优化器权重不会发生变化,但它也会使 CPU 内存占用增加(每个 rank 的检查点大小 × rank 总数)。此外,用户应注意系统内存的限制。具体而言,固定内存(pinned memory)使用的是页面锁定内存,与可分页内存相比,这类内存可能较为稀缺。
- 检查点管理 - 由于检查点保存是异步的,用户需要自行管理并发运行的检查点请求。通常情况下,用户可以通过处理 async_save 返回的 future 对象来实现自己的管理策略。对于大多数用户,我们建议同一时刻最多只保留一个进行中的异步检查点请求,以避免每多一个请求就多一份内存压力。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
# Prefix for checkpoint locations; each `dcp.async_save` call in the training
# loop below writes to `checkpoint_id=f"{CHECKPOINT_DIR}_step{step}"`.
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
    """Small two-layer MLP used as the example model.

    Architecture: Linear(16 -> 16) -> ReLU -> Linear(16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # names and initialization order are unchanged.
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)
def setup(rank, world_size):
    """Initialize the default process group for one spawned worker.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of workers.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # No stray whitespace: the value is parsed as a TCP port number.
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its GPU before creating the process group, so the
    # CUDA backend and later `device="cuda"` allocations use this device.
    torch.cuda.set_device(rank)
    # FSDP runs its collectives on CUDA tensors (NCCL), while DCP's async_save
    # requires the group to also have a CPU backend (Gloo).
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group created by ``setup``.

    Callers should make sure any in-flight checkpoint request has finished
    first, since the group cannot be used after this call.
    """
    teardown = dist.destroy_process_group
    teardown()
def run_fsdp_checkpoint_save_example(rank, world_size):
    """Train a toy FSDP model for 10 steps, issuing an async checkpoint each step.

    At most one checkpoint request is in flight at a time: before issuing a new
    ``dcp.async_save`` the previous request's future is awaited.

    Args:
        rank: Index of this worker process (also its CUDA device index).
        world_size: Total number of worker processes.
    """
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # Create a model, move it to GPU `rank`, and shard it with FSDP.
    model = ToyModel().to(rank)
    model = fully_shard(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # Wait for the previous checkpoint (if any) to finish, so that no more
        # than one checkpoint request is queued at a time.
        if checkpoint_future is not None:
            checkpoint_future.result()

        state_dict = {"app": AppState(model, optimizer)}
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    # The last request may still be running: wait for it (and surface any error)
    # before destroying the process group, otherwise the final checkpoint may
    # fail or be left incomplete.
    if checkpoint_future is not None:
        checkpoint_future.result()
    cleanup()
if __name__ == "__main__":
    # One worker process per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn(nprocs=0) would silently do nothing; fail loudly instead.
        raise SystemExit("This example requires at least one CUDA device.")
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
通过固定内存(Pinned Memory)获得更高性能
如果上述优化仍无法满足性能要求,您可以对 GPU 模型使用一项额外优化:利用固定内存缓冲区进行检查点暂存(staging)。具体来说,该优化针对的是异步检查点的主要开销,即将数据复制到内存中检查点缓冲区的过程。通过在多次检查点请求之间保留同一个固定内存缓冲区,用户可以利用直接内存访问 (DMA) 来加快这一复制。
注意
此优化的主要缺点是缓冲区在检查点步骤之间会持续占用空间。如果不使用固定内存优化(如上文所述),任何检查点缓冲区都会在检查点完成时立即释放。而在固定内存实现中,该缓冲区会在步骤之间保持,导致整个应用程序生命周期内持续存在相同的峰值内存压力。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
# Prefix for checkpoint locations: it is the storage writer's initial `path`,
# and each `dcp.async_save` call in the training loop below passes
# `checkpoint_id=f"{CHECKPOINT_DIR}_step{step}"`.
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
    """Small two-layer MLP used as the example model.

    Architecture: Linear(16 -> 16) -> ReLU -> Linear(16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # names and initialization order are unchanged.
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)
def setup(rank, world_size):
    """Initialize the default process group for one spawned worker.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of workers.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # No stray whitespace: the value is parsed as a TCP port number.
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its GPU before creating the process group, so the
    # CUDA backend and later `device="cuda"` allocations use this device.
    torch.cuda.set_device(rank)
    # FSDP runs its collectives on CUDA tensors (NCCL), while DCP's async_save
    # requires the group to also have a CPU backend (Gloo).
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group created by ``setup``.

    Callers should make sure any in-flight checkpoint request has finished
    first, since the group cannot be used after this call.
    """
    teardown = dist.destroy_process_group
    teardown()
def run_fsdp_checkpoint_save_example(rank, world_size):
    """Train a toy FSDP model for 10 steps, checkpointing asynchronously each step
    through a storage writer that keeps a persistent pinned-memory staging buffer.

    At most one checkpoint request is in flight at a time: before issuing a new
    ``dcp.async_save`` the previous request's future is awaited.

    Args:
        rank: Index of this worker process (also its CUDA device index).
        world_size: Total number of worker processes.
    """
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # Create a model, move it to GPU `rank`, and shard it with FSDP.
    model = ToyModel().to(rank)
    model = fully_shard(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is the
    # process of copying checkpoints into in-memory buffers. Setting
    # `cache_staged_state_dict=True` keeps a persistent pinned-memory buffer
    # that later requests copy into.
    # Note: the writer must persist between checkpoint requests, since it owns
    # that pinned-memory buffer.
    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = {"app": AppState(model, optimizer)}
        if checkpoint_future is not None:
            # Wait for the previous checkpoint to finish, so that no more than
            # one checkpoint request is queued at a time.
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(
            state_dict,
            storage_writer=writer,
            checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
        )

    # The last request may still be running: wait for it (and surface any error)
    # before destroying the process group, otherwise the final checkpoint may
    # fail or be left incomplete.
    if checkpoint_future is not None:
        checkpoint_future.result()
    cleanup()
if __name__ == "__main__":
    # One worker process per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn(nprocs=0) would silently do nothing; fail loudly instead.
        raise SystemExit("This example requires at least one CUDA device.")
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
使用 DefaultStager 进行完全异步暂存
2.9 版本新增: async_stager 参数和 DefaultStager 类于 PyTorch 2.9 中引入。
虽然 async_save 异步处理磁盘写入,但将数据从 GPU 复制到 CPU 的过程(称为“暂存”)通常发生在主线程上。即使使用了固定内存,这种设备到主机 (D2H) 的复制操作也可能阻塞大型模型的训练循环。
为了在计算和检查点保存之间实现最大程度的重叠,我们可以使用 DefaultStager。该组件将状态字典(state dictionary)的创建和 D2H 复制操作卸载到后台线程执行。
时间轴比较
标准 async_save:
[GPU 计算] -> [CPU 复制 (阻塞)] -> [磁盘写入 (异步)]
使用 DefaultStager(AsyncStager 的默认实现):
[GPU 计算] || [CPU 复制 (异步)] -> [磁盘写入 (异步)]
注意
使用 DefaultStager 进行异步暂存会引入一个消耗 CPU 资源的后台线程。请确保您的环境中有足够的 CPU 核心来处理此任务,以免影响主训练进程。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.staging import DefaultStager
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
# Prefix for checkpoint locations; each `dcp.async_save` call in the training
# loop below writes to `checkpoint_id=f"{CHECKPOINT_DIR}_step{step}"`.
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
    """Small two-layer MLP used as the example model.

    Architecture: Linear(16 -> 16) -> ReLU -> Linear(16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # names and initialization order are unchanged.
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)
def setup(rank, world_size):
    """Initialize the default process group for one spawned worker.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of workers.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # No stray whitespace: the value is parsed as a TCP port number.
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its GPU before creating the process group, so the
    # CUDA backend and later `device="cuda"` allocations use this device.
    torch.cuda.set_device(rank)
    # FSDP runs its collectives on CUDA tensors (NCCL), while DCP's async_save
    # requires the group to also have a CPU backend (Gloo).
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group created by ``setup``.

    Callers should make sure any in-flight checkpoint request has finished
    first, since the group cannot be used after this call.
    """
    teardown = dist.destroy_process_group
    teardown()
def run_fsdp_checkpoint_save_example(rank, world_size):
    """Train a toy FSDP model for 10 steps with fully asynchronous checkpointing.

    A single ``DefaultStager`` performs the GPU-to-CPU staging copy on a
    background thread, overlapping it with the next step's forward/backward.
    ``dcp.async_save`` then returns a response exposing separate
    ``staging_completion`` and ``upload_completion`` futures.

    Args:
        rank: Index of this worker process (also its CUDA device index).
        world_size: Total number of worker processes.
    """
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # Create a model, move it to GPU `rank`, and shard it with FSDP.
    model = ToyModel().to(rank)
    model = fully_shard(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # Create the stager once and reuse it for every request: it owns the
    # background staging thread and its staging buffers. Building a new one per
    # step would re-allocate those resources every time and never release them.
    stager = DefaultStager()

    checkpoint_future = None
    for step in range(10):
        print(f"Step {step} starting...")
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()

        # Critical: the previous checkpoint's D2H copy (staging) must be
        # complete before the optimizer modifies the model parameters.
        # Waiting here, AFTER the backward pass, lets that copy overlap with
        # this step's forward and backward computation.
        if checkpoint_future is not None:
            checkpoint_future.staging_completion.result()
        optimizer.step()

        # Wait for the previous checkpoint to be fully written, so that no
        # more than one checkpoint request is in flight at a time.
        if checkpoint_future is not None:
            checkpoint_future.upload_completion.result()

        state_dict = {"app": AppState(model, optimizer)}
        # Passing the stager enables fully asynchronous staging; the returned
        # AsyncSaveResponse exposes distinct futures for staging and upload.
        checkpoint_future = dcp.async_save(
            state_dict,
            checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
            async_stager=stager,
        )

    # Ensure the last checkpoint completes before releasing the stager and
    # tearing down the process group.
    if checkpoint_future is not None:
        checkpoint_future.upload_completion.result()
    stager.close()
    cleanup()
if __name__ == "__main__":
    # One worker process per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn(nprocs=0) would silently do nothing; fail loudly instead.
        raise SystemExit("This example requires at least one CUDA device.")
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
结论
总之,我们学习了如何使用 DCP 的 async_save() API 在训练关键路径之外生成检查点。我们还了解了使用此 API 所带来的额外内存和并发开销,以及两种进一步提速的优化方法:利用固定内存缓冲区加速暂存复制,以及通过 DefaultStager 将暂存过程卸载到后台线程。