注意
转到末尾下载完整的示例代码。
(原型)使用 GPUDirect Storage 加速 torch.save
和 torch.load
#
GPUDirect Storage 为 GPU 内存和存储之间的直接内存访问(DMA)传输提供了一条直接数据路径,从而避免了通过 CPU 的反弹缓冲区(bounce buffer)。
在 2.7 版本中,我们向 torch.cuda.gds
引入了新的原型 API,它们作为 cuFile APIs 的轻量级包装器,可与 torch.Tensor
一起使用以实现更高的 I/O 性能。
在本教程中,我们将演示如何在本地文件系统上,结合使用 torch.cuda.gds
API 与由 torch.save
和 torch.load
生成的检查点。
了解如何在本地文件系统上,结合使用
torch.cuda.gds
API 与由torch.save
和torch.load
生成的检查点
PyTorch v.2.7.0 或更高版本
必须按照此文档安装 GPUDirect Storage
确保您保存/加载到的文件系统支持 GPUDirect Storage。
结合 torch.save
和 torch.load
使用 GPUDirect Storage#
GPUDirect Storage 需要 4KB 的存储对齐。您可以使用 torch.utils.serialization.config.save.storage_alignment
来切换此设置。
- 该过程涉及的步骤如下:
写入不含任何实际数据的检查点文件。这会在磁盘上预留空间。
使用
FakeTensor
读取检查点中每个张量对应存储的偏移量。使用
GDSFile
在这些偏移量处写入相应的数据。
给定一个位于 GPU 上的张量的状态字典,可以使用 torch.serialization.skip_data
上下文管理器来保存一个包含所有相关元数据但不包含存储字节的检查点。对于状态字典中的每个 torch.Storage
,将在检查点内为存储字节预留空间。
我们可以通过在 FakeTensorMode
下加载来获取每个存储应写入到检查点内的偏移量。FakeTensor 是一个包含张量元数据(如尺寸、步长、数据类型、设备)信息但没有任何存储字节的张量。以下代码片段不会物化任何数据,但会为每个 FakeTensor
标记其在检查点中对应的偏移量。
如果您在训练期间连续保存相同的状态字典,您只需获取一次偏移量,之后便可重复使用相同的偏移量。同样,如果一个张量将被重复保存或加载,您可以使用 torch.cuda.gds.gds_register_buffer
,它包装了 cuFileBufRegister
,将存储注册为 GDS 缓冲区。
请注意,torch.cuda.gds.GdsFile.save_storage
绑定到同步的 cuFileWrite
API,因此之后不需要同步。
我们通过 torch.load
加载并比较来验证保存的检查点的正确性。
加载流程是相反的:您可以使用带有 torch.serialization.skip_data
上下文管理器的 torch.load
来加载除存储字节之外的所有内容。这意味着检查点中的任何张量都将被创建,但它们的存储将是空的(就像通过 torch.empty
创建的张量一样)。
我们再次使用 FakeTensorMode
来获取检查点偏移量,并确定加载的检查点与保存的检查点相同。
与 torch.cuda.gds.GdsFile.save_storage
类似,torch.cuda.gds.GdsFile.load_storage
绑定到同步的 cuFileRead
API,因此之后不需要同步。
结论#
在本教程中,我们演示了如何在本地文件系统上,结合使用原型 torch.cuda.gds
API 与 torch.save
和 torch.load
。如果您有任何反馈,请在 PyTorch GitHub 仓库中提交一个 issue。
# %%%%%%RUNNABLE_CODE_REMOVED%%%%%%