注意
转到末尾 下载完整的示例代码。
(原型)使用 GPUDirect Storage 加速 torch.save 和 torch.load#
GPUDirect Storage 实现了 GPU 内存和存储之间的直接内存访问传输的直接数据路径,避免了通过 CPU 的短暂缓冲。
在 **2.7** 版本中,我们引入了 torch.cuda.gds 的新原型 API,它们是 cuFile API 的轻量级封装,可与 torch.Tensor 一起使用,以提高 I/O 性能。
在本教程中,我们将演示如何在本地文件系统上使用 torch.cuda.gds API 结合 torch.save 和 torch.load 生成的检查点。
了解如何在本地文件系统上使用
torch.cuda.gdsAPI 结合torch.save和torch.load生成的检查点
PyTorch v.2.7.0 或更高版本
必须根据 文档 安装 GPUDirect Storage
确保您正在保存/加载到的文件系统支持 GPUDirect Storage。
将 GPUDirect Storage 与 torch.save 和 torch.load 结合使用#
GPUDirect Storage 需要 4KB 的存储对齐。您可以使用 torch.utils.serialization.config.save.storage_alignment 来切换此设置。
import torch
from torch.utils.serialization import config as serialization_config
serialization_config.save.storage_alignment = 4096
- 该过程涉及的步骤如下:
写入检查点文件,而不写入任何实际数据。这会在磁盘上预留空间。
使用
FakeTensor读取检查点中与每个张量关联的存储的偏移量。使用
GDSFile在这些偏移量处写入相应的数据。
给定一个位于 GPU 上的张量状态字典,可以使用 torch.serialization.skip_data 上下文管理器来保存一个检查点,该检查点包含除存储字节以外的所有相关元数据。对于状态字典中的每个 torch.Storage,将在检查点内为存储字节预留空间。
import torch.nn as nn
m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data():
torch.save(sd, "checkpoint.pt")
我们可以通过在 FakeTensorMode 下加载来获取每个存储在检查点内应写入的偏移量。FakeTensor 是一个具有张量元数据(如大小、步幅、dtype、设备)但没有存储字节的张量。以下代码片段不会具体化任何数据,但会将每个 FakeTensor 标记为在检查点内与该张量对应的偏移量。
如果您在训练过程中持续保存相同的状态字典,则只需获取一次偏移量,并且可以重复使用相同的偏移量。同样,如果一个张量将被重复保存或加载,您可以使用 torch.cuda.gds.gds_register_buffer,它封装了 cuFileBufRegister 以将存储注册为 GDS 缓冲区。
请注意,torch.cuda.gds.GdsFile.save_storage 绑定到同步 cuFileWrite API,因此之后不需要进行同步。
import os
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")
for k, v in fake_sd.items():
print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")
f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)
for k, v in sd.items():
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# save_storage is a wrapper around `cuFileWrite`
f.save_storage(v.untyped_storage(), offset)
我们通过 torch.load 并进行比较来验证保存的检查点的正确性。
sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
加载流程是反向的:您可以使用 torch.load 和 torch.serialization.skip_data 上下文管理器来加载除存储字节外的所有内容。这意味着检查点中的任何张量都将被创建,但它们的存储将是空的(就像张量是通过 torch.empty 创建的一样)。
with torch.serialization.skip_data():
sd_loaded = torch.load("checkpoint.pt")
我们再次使用 FakeTensorMode 来获取检查点偏移量,并确定加载的检查点与保存的检查点相同。
与 torch.cuda.gds.GdsFile.save_storage 类似,torch.cuda.gds.GdsFile.load_storage 绑定到同步 cuFileRead API,因此之后不需要进行同步。
for k, v in sd_loaded.items():
assert not torch.equal(v, sd[k])
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# load_storage is a wrapper around `cuFileRead`
f.load_storage(v.untyped_storage(), offset)
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
del f
结论#
在本教程中,我们演示了如何在本地文件系统上使用原型 torch.cuda.gds API 结合 torch.save 和 torch.load。如果您有任何反馈,请在 PyTorch GitHub 仓库中提交一个 issue。