评价此页

torch.save#

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)[源代码]#

将对象保存到磁盘文件。

另请参阅:保存和加载张量

有关操作 checkpoint 的更高级工具,请参阅布局控制

参数
  • obj (object) – 要保存的对象

  • f (Union[str, PathLike[str], IO[bytes]]) – 一个类文件对象(必须实现 write 和 flush 方法),或一个包含文件名 的字符串或 os.PathLike 对象

  • pickle_module (Any) – 用于腌制元数据和对象的模块

  • pickle_protocol (int) – 可以指定以覆盖默认协议

注意

一个常见的 PyTorch 约定是使用 .pt 文件扩展名来保存张量。

注意

PyTorch 在序列化过程中会保留存储共享。有关更多详细信息,请参阅保存和加载张量会保留视图

注意

PyTorch 的 1.6 版本将 `torch.save` 切换为使用新的基于 zipfile 的文件格式。`torch.load` 仍然能够加载旧格式的文件。如果您出于任何原因希望 `torch.save` 使用旧格式,请传递关键字参数 `_use_new_zipfile_serialization=False`。

示例

>>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, "tensor.pt")
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)