评价此页

分布式检查点 (DCP) 入门#

创建日期:2023 年 10 月 2 日 | 最后更新:2025 年 7 月 10 日 | 最后验证:2024 年 11 月 5 日

作者Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

注意

editgithub 上查看并编辑本教程。

先决条件

在分布式训练期间对 AI 模型进行检查点保存可能具有挑战性,因为参数和梯度分布在不同的训练器上,且当您恢复训练时,可用的训练器数量可能会发生变化。PyTorch 分布式检查点 (DCP) 可以帮助简化此过程。

在本教程中,我们将展示如何配合一个简单的 FSDP 封装模型使用 DCP API。

DCP 的工作原理#

torch.distributed.checkpoint() 能够并行地从多个进程(ranks)保存和加载模型。您可以使用此模块在任意数量的进程上并行保存,然后在加载时跨不同的集群拓扑结构重新分片。

此外,通过使用 torch.distributed.checkpoint.state_dict() 中的模块,DCP 提供了对在分布式环境中优雅地处理 state_dict 生成和加载的支持。这包括管理模型和优化器之间的全限定名 (FQN) 映射,以及为 PyTorch 提供的并行机制设置默认参数。

DCP 与 torch.save()torch.load() 在以下几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个进程至少对应一个文件。

  • 它原地(in-place)操作,意味着模型应先分配其数据,DCP 则使用该存储空间。

  • DCP 对有状态对象(在 torch.distributed.checkpoint.stateful 中正式定义)提供特殊处理,如果定义了 state_dictload_state_dict 方法,它会自动调用这些方法。

注意

本教程中的代码运行在 8-GPU 服务器上,但可以轻松推广到其他环境。

如何使用 DCP#

这里我们使用一个用 FSDP 封装的玩具模型进行演示。同样地,这些 API 和逻辑可以应用于更大模型的检查点保存。

保存#

现在,让我们创建一个玩具模块,用 FSDP 封装它,输入一些虚拟数据,然后保存它。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

请查看 checkpoint 目录。您应该会看到与下述文件数量对应的检查点文件。例如,如果您有 8 个设备,则应该看到 8 个文件。

Distributed Checkpoint

加载#

保存后,让我们创建相同的 FSDP 封装模型,并将存储中保存的 state dict 加载到模型中。您可以在相同的世界大小(world size)或不同的世界大小下加载。

请注意,您必须在加载前调用 model.state_dict() 并将其传递给 DCP 的 load_state_dict() API。这与 torch.load() 有本质区别,因为 torch.load() 在加载前仅需要检查点路径。我们需要在加载前获取 state_dict 的原因是:

  • DCP 使用来自模型 state_dict 的预分配存储空间从检查点目录加载数据。在加载过程中,传入的 state_dict 将会被原地更新。

  • DCP 在加载前需要来自模型的分片信息以支持重新分片。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

如果您想将保存的检查点加载到非分布式设置下的非 FSDP 模型中(例如用于推理),也可以使用 DCP。默认情况下,DCP 以单程序多数据 (SPMD) 风格保存和加载分布式 state_dict。但是,如果没有初始化进程组,DCP 会推断目的是以“非分布式”风格(即完全在当前进程中)进行保存或加载。

注意

针对多程序多数据 (MPMD) 的分布式检查点支持仍在开发中。

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

格式#

尚未提到的一个缺点是,DCP 保存检查点的格式与使用 torch.save 生成的格式本质上不同。由于这在用户希望与习惯使用 torch.save 格式的用户共享模型,或者仅仅是想增加应用程序的格式灵活性时可能会成为问题,因此我们提供了 torch.distributed.checkpoint.format_utils 模块。

为方便用户,提供了一个命令行工具,其格式如下:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

在上述命令中,modetorch_to_dcpdcp_to_torch 之一。

或者,对于希望直接转换检查点的用户,也提供了相应的方法。

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

结论#

总之,我们已经学习了如何使用 DCP 的 save()load() API,以及它们与 torch.save()torch.load() 的区别。此外,我们还学习了如何使用 get_state_dict()set_state_dict() 来自动管理 state dict 生成和加载过程中特定于并行机制的 FQN 和默认设置。

欲了解更多信息,请参阅以下内容: