Getting Started with Distributed Checkpoint (DCP)#

Created On: Oct 02, 2023 | Last Updated: Jul 10, 2025 | Last Verified: Nov 05, 2024

Authors: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

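Checkpointing AI models during distributed training can be challenging, because parameters and gradients are partitioned across trainers and the number of available trainers may change when you resume training. PyTorch Distributed Checkpoint (DCP) can help simplify this process.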

In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.

How DCP works#

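The torch.distributed.checkpoint module supports saving and loading models from multiple ranks in parallel. You can use it to save from any number of ranks in parallel and then re-shard across a different cluster topology at load time.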
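To make the idea of re-sharding concrete, here is a minimal pure-Python sketch. It is not how DCP is implemented: plain lists stand in for sharded tensors, and the `shard` helper and the 4-rank/2-rank layout are assumptions chosen only for illustration.

```python
# Illustrative only: plain Python lists stand in for sharded tensors.
# This is NOT DCP's implementation; it only pictures the re-sharding idea.

def shard(values, world_size):
    """Split `values` into `world_size` contiguous chunks (the last may be smaller)."""
    chunk = -(-len(values) // world_size)  # ceiling division
    return [values[r * chunk:(r + 1) * chunk] for r in range(world_size)]


full_param = list(range(10))  # hypothetical flattened parameter

# "Save" from 4 ranks: each rank writes only its own shard.
saved_shards = shard(full_param, world_size=4)

# "Load" on 2 ranks: reassemble the logical tensor, then split it again.
logical = [v for s in saved_shards for v in s]
loaded_shards = shard(logical, world_size=2)

print(saved_shards)
print(loaded_shards)
```

The point of the sketch is that the checkpoint describes one logical tensor, so the number of shards at load time does not have to match the number at save time.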

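In addition, through the helpers in the torch.distributed.checkpoint.state_dict module, DCP gracefully handles state_dict generation and loading in distributed settings. This includes managing the fully qualified name (FQN) mappings between models and optimizers, and setting default parameters for the parallelisms that PyTorch provides.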
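The FQN mapping problem can be pictured with plain Python data. The sketch below is only an illustration under stated assumptions: the parameter names and the `exp_avg` values are hypothetical, and the real helpers operate on actual modules and optimizers rather than on dictionaries like these.

```python
# Illustrative only: a plain-Python picture of the FQN mapping problem.
# All names and values below are hypothetical.

# Parameter FQNs in registration order, as a model would report them.
param_fqns = ["net1.weight", "net1.bias", "net2.weight", "net2.bias"]

# An optimizer state dict keys its per-parameter state by integer index.
optim_state_by_index = {
    0: {"exp_avg": 0.1},
    1: {"exp_avg": 0.2},
    2: {"exp_avg": 0.3},
    3: {"exp_avg": 0.4},
}

# Re-key the optimizer state by FQN so it lines up with the model's names.
optim_state_by_fqn = {
    param_fqns[index]: state for index, state in optim_state_by_index.items()
}

print(optim_state_by_fqn["net2.weight"])
```

Keying both the model and the optimizer state by the same names is what lets a checkpoint be matched back to the right parameters when the sharding layout changes.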

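DCP differs from torch.save() and torch.load() in a few significant ways:

  • It produces multiple files per checkpoint, with at least one file per rank.

  • It operates in place, meaning that the model should allocate its data first and DCP then uses that storage.

  • DCP offers special handling of stateful objects (formally defined in torch.distributed.checkpoint.stateful), automatically calling their state_dict and load_state_dict methods if they are defined.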

Note

The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.

How to use DCP#

Here we use a toy model wrapped with FSDP for demonstration purposes. The same APIs and logic can be applied to checkpoint larger models.

Saving#

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Now check the checkpoint directory. You should see one checkpoint file per rank, as shown below. For example, if you have 8 devices, you should see 8 files.

[Figure: Distributed Checkpoint]

Loading#

After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into it. You can load with the same world size or with a different one.

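Note that you need to obtain the model's state_dict (for example by calling model.state_dict()) before loading and pass it to DCP's load() API. This is fundamentally different from torch.load(), which only requires the path to the checkpoint. We need the state_dict before loading for two reasons:

  • DCP uses the storage pre-allocated by the model's state_dict to load from the checkpoint directory. During loading, the state_dict that is passed in is updated in place.

  • DCP requires the model's sharding information before loading in order to support re-sharding.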

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

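If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, for example for inference, you can also do that with DCP. By default, DCP saves and loads a distributed state_dict in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.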

Note

Distributed checkpoint support for Multi-Program Multi-Data (MPMD) is still under development.

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
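    # the checkpoint above was saved as {"app": AppState(...)}, so the model's
    # tensors are stored under "app" -> "model"; mirror that nesting here
    state_dict = {
        "app": {"model": model.state_dict()},
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["app"]["model"])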

if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

Formats#

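One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from the one produced by torch.save. This can be an issue when users wish to share models with others who are used to the torch.save format, or simply want to add format flexibility to their applications. For this case, DCP provides the torch.distributed.checkpoint.format_utils module.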

For convenience, a command line utility is provided, which takes the following form:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

In the command above, mode is one of torch_to_dcp or dcp_to_torch.

Alternatively, methods are also provided for users who wish to convert checkpoints directly.

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

Conclusion#

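In conclusion, we have learned how to use DCP's save() and load() APIs, and how they differ from torch.save() and torch.load(). We have also learned how to use get_state_dict() and set_state_dict() to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.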

For more information, please see the following: