评价此页
torch.savetorch.load 结合 GPUDirect Storage">

(原型)使用 GPUDirect Storage 加速 torch.savetorch.load#

GPUDirect Storage 实现了 GPU 内存和存储之间的直接内存访问传输的直接数据路径,避免了通过 CPU 的短暂缓冲。

在 **2.7** 版本中,我们引入了 torch.cuda.gds 的新原型 API,它们是 cuFile API 的轻量级封装,可与 torch.Tensor 一起使用,以提高 I/O 性能。

在本教程中,我们将演示如何在本地文件系统上使用 torch.cuda.gds API 结合 torch.savetorch.load 生成的检查点。

您将学到什么
  • 了解如何在本地文件系统上使用 torch.cuda.gds API 结合 torch.savetorch.load 生成的检查点

先决条件
  • PyTorch v.2.7.0 或更高版本

  • 必须根据 文档 安装 GPUDirect Storage

  • 确保您正在保存/加载到的文件系统支持 GPUDirect Storage。

将 GPUDirect Storage 与 torch.savetorch.load 结合使用#

GPUDirect Storage 需要 4KB 的存储对齐。您可以使用 torch.utils.serialization.config.save.storage_alignment 来切换此设置。

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096
该过程涉及的步骤如下:
  • 写入检查点文件,而不写入任何实际数据。这会在磁盘上预留空间。

  • 使用 FakeTensor 读取检查点中与每个张量关联的存储的偏移量。

  • 使用 GDSFile 在这些偏移量处写入相应的数据。

给定一个位于 GPU 上的张量状态字典,可以使用 torch.serialization.skip_data 上下文管理器来保存一个检查点,该检查点包含除存储字节以外的所有相关元数据。对于状态字典中的每个 torch.Storage,将在检查点内为存储字节预留空间。

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

我们可以通过在 FakeTensorMode 下加载来获取每个存储在检查点内应写入的偏移量。FakeTensor 是一个具有张量元数据(如大小、步幅、dtype、设备)但没有存储字节的张量。以下代码片段不会具体化任何数据,但会将每个 FakeTensor 标记为在检查点内与该张量对应的偏移量。

如果您在训练过程中持续保存相同的状态字典,则只需获取一次偏移量,并且可以重复使用相同的偏移量。同样,如果一个张量将被重复保存或加载,您可以使用 torch.cuda.gds.gds_register_buffer,它封装了 cuFileBufRegister 以将存储注册为 GDS 缓冲区。

请注意,torch.cuda.gds.GdsFile.save_storage 绑定到同步 cuFileWrite API,因此之后不需要进行同步。

import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around `cuFileWrite`
    f.save_storage(v.untyped_storage(), offset)

我们通过 torch.load 并进行比较来验证保存的检查点的正确性。

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

加载流程是反向的:您可以使用 torch.loadtorch.serialization.skip_data 上下文管理器来加载除存储字节外的所有内容。这意味着检查点中的任何张量都将被创建,但它们的存储将是空的(就像张量是通过 torch.empty 创建的一样)。

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

我们再次使用 FakeTensorMode 来获取检查点偏移量,并确定加载的检查点与保存的检查点相同。

torch.cuda.gds.GdsFile.save_storage 类似,torch.cuda.gds.GdsFile.load_storage 绑定到同步 cuFileRead API,因此之后不需要进行同步。

for k, v in sd_loaded.items():
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around `cuFileRead`
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

del f

结论#

在本教程中,我们演示了如何在本地文件系统上使用原型 torch.cuda.gds API 结合 torch.savetorch.load。如果您有任何反馈,请在 PyTorch GitHub 仓库中提交一个 issue。