评价此页

序列化语义#

创建于:2017年2月26日 | 最后更新于:2025年5月19日

本笔记描述了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++ 中加载。

保存和加载张量#

torch.save()torch.load() 允许您轻松保存和加载张量

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch 文件通常以“.pt”或“.pth”扩展名写入。

torch.save()torch.load() 默认使用 Python 的 pickle,因此您还可以将多个张量作为元组、列表和字典等 Python 对象的一部分进行保存

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

包含 PyTorch 张量的自定义数据结构如果可 pickle,也可以保存。

保存和加载张量保留视图#

保存张量保留它们的视图关系

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅张量视图

当 PyTorch 保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个实现细节,未来可能会改变,但它通常会节省空间,并允许 PyTorch 轻松重构加载张量之间的视图关系。例如,在上面的代码片段中,只有一个存储写入“tensors.pt”。

然而,在某些情况下,保存当前存储对象可能是不必要的,并且会创建过大的文件。在下面的代码片段中,一个远大于已保存张量的存储被写入文件

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

除了将 small 张量中的五个值保存到“small.pt”之外,还保存并加载了它与 large 共享的存储中的 999 个值。

当保存的张量元素少于其存储对象时,可以通过首先克隆张量来减小保存文件的大小。克隆张量会生成一个新张量,其中包含一个新的存储对象,该对象仅包含张量中的值

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

然而,由于克隆的张量彼此独立,它们不具有原始张量所具有的任何视图关系。如果文件大小和视图关系在保存小于其存储对象的张量时都很重要,则必须注意构造新的张量,以最大限度地减小其存储对象的大小,但在保存之前仍具有所需的视图关系。

保存和加载 torch.nn.Modules#

另请参阅:教程:保存和加载模块

在 PyTorch 中,模块的状态经常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久缓冲区

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出于兼容性原因,建议不要直接保存模块,而是仅保存其状态字典。Python 模块甚至有一个函数,load_state_dict(),用于从状态字典恢复其状态

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

请注意,状态字典首先使用 torch.load() 从其文件加载,然后使用 load_state_dict() 恢复状态。

即使是自定义模块和包含其他模块的模块也具有状态字典,并且可以使用此模式

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

torch.save 的序列化文件格式#

自 PyTorch 1.6.0 起,除非用户将 _use_new_zipfile_serialization=False,否则 torch.save 默认为返回未压缩的 ZIP64 存档。

在此存档中,文件按以下顺序排列

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
条目如下
  • data.pkl 是对传递给 torch.save 的对象进行 pickle 的结果,不包括它包含的 torch.Storage 对象

  • byteorder 包含一个字符串,其中包含保存时的 sys.byteorder(“little”或“big”)

  • data/ 包含对象中的所有存储,其中每个存储都是一个单独的文件

  • version 包含保存时的版本号,可在加载时使用

保存时,PyTorch 将确保每个文件的本地文件头填充到 64 字节的倍数偏移量,确保每个文件的偏移量是 64 字节对齐的。

注意

某些设备(例如 XLA)上的张量被序列化为 pickled numpy 数组。因此,它们的存储未序列化。在这些情况下,检查点中可能不存在 data/

布局控制#

torch.load() 中的 mmap 参数允许延迟加载张量存储。

此外,还有一些高级功能允许对 torch.save 检查点进行更细粒度的控制和操作。

torch.serialization.skip_data 上下文管理器启用
  • 使用 torch.save 保存一个检查点,其中包含稍后要写入的数据字节的空白空间。

  • 使用 torch.load 加载检查点,并稍后填充张量的数据字节。

要检查 torch.save 检查点中的张量元数据而不为存储数据分配内存,请在 FakeTensorMode 上下文管理器中使用 torch.load。除了跳过加载存储数据(类似于上面的 skip_data)之外,它还会用检查点内的偏移量标记存储,从而实现直接的检查点操作。

import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

m = nn.Linear(10, 10)
torch.save(m.state_dict(), "checkpoint.pt")

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
    # offset of the storage in the checkpoint
    print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")

有关更多信息,本教程提供了使用这些功能操作检查点的综合示例。

使用 weights_only=Truetorch.load#

从 2.6 版开始,如果未传递 pickle_module 参数,torch.load 将使用 weights_only=True

正如 torch.load() 的文档中所述,weights_only=Truetorch.load 中使用的反 pickle 器限制为仅执行普通 torch.Tensorsstate_dicts 以及其他一些原始类型所需的函数/构建类。此外,与 pickle 模块提供的默认 Unpickler 不同,weights_only Unpickler 不允许在反 pickle 过程中动态导入任何内容。

如上所述,使用 torch.save 时,保存模块的 state_dict 是最佳实践。如果加载包含 nn.Module 的旧检查点,我们建议使用 weights_only=False。加载包含张量子类的检查点时,可能需要将函数/类列入白名单,有关详细信息,请参阅下文。

如果 weights_only Unpickler 在 pickle 文件中遇到默认情况下未列入白名单的函数或类,您应该会看到类似以下的可操作错误

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

请按照错误消息中的步骤操作,仅在信任它们的情况下才将函数或类列入白名单。

要获取检查点中尚未列入白名单的所有全局变量(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它将返回一个字符串列表,形式为 {__module__}.{__name__}。如果您信任这些函数/类,您可以导入它们并根据错误消息通过 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals 将它们列入白名单。

要访问用户列入白名单的函数/类列表,您可以使用 torch.serialization.get_safe_globals(),要清除当前列表,请参阅 torch.serialization.clear_safe_globals()

weights_only 疑难解答#

获取不安全的全局变量#

一个注意事项是 torch.serialization.get_unsafe_globals_in_checkpoint() 静态分析检查点,某些类型可能在反 pickle 过程中动态构建,因此不会由 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。一个这样的例子是 numpy 中的 dtypes。在 numpy < 1.25 中,在将 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的所有函数/类列入白名单后,您可能会看到如下错误

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 列入白名单。

numpy >=1.25 中,您将看到

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) 列入白名单。

环境变量#

有两个环境变量会影响 torch.load 的行为。如果无法访问 torch.load 调用站点,这些变量可能会有所帮助。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 将覆盖所有 torch.load 调用站点以使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 将使 torch.load 调用站点使用 weights_only=False 仅当 weights_only 未作为参数传递时。

序列化 torch.nn.Modules 并在 C++ 中加载它们#

另请参阅:教程:在 C++ 中加载 TorchScript 模型

ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load() 加载。此序列化编码了模块的所有方法、子模块、参数和属性,并且它允许序列化程序在 C++ 中加载(即没有 Python)。

torch.jit.save()torch.save() 之间的区别可能不立即清楚。torch.save() 使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。torch.jit.save() 另一方面,将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时很有用,这是部署 PyTorch 模型时的常见做法。

在 Python 中编写脚本、序列化和加载模块

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟踪模块也可以使用 torch.jit.save() 保存,但需要注意的是,只序列化跟踪的代码路径。以下示例演示了这一点

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上述模块有一个 if 语句,它未被跟踪输入触发,因此不属于跟踪模块,也未随之序列化。然而,脚本模块包含 if 语句并随之序列化。有关脚本和跟踪的更多信息,请参阅TorchScript 文档

最后,在 C++ 中加载模块

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅PyTorch C++ API 文档

跨 PyTorch 版本保存和加载 ScriptModules#

PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。旧版本的 PyTorch 可能不支持较新的模块,而较新版本可能已删除或修改了旧行为。这些更改在 PyTorch 的发行说明中明确描述,依赖于已更改功能的模块可能需要更新才能继续正常工作。在某些有限情况下(如下详述),PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。

torch.div 执行整数除法#

在 PyTorch 1.5 及更早版本中,torch.div() 在给定两个整数输入时会执行向下取整除法

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

然而,在 PyTorch 1.7 中,torch.div() 将始终对其输入执行真除法,就像 Python 3 中的除法一样

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行为保留在序列化 ScriptModules 中。也就是说,在 1.6 之前的 PyTorch 版本中序列化的 ScriptModules 即使在使用更新版本的 PyTorch 加载时,在给定两个整数输入时仍会看到 torch.div() 执行向下取整除法。然而,使用 torch.div() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在更早的 PyTorch 版本中加载,因为那些更早的版本不理解新行为。

torch.full 总是推断浮点 dtype#

在 PyTorch 1.5 及更早版本中,torch.full() 总是返回一个浮点张量,无论给定填充值是什么

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

然而,在 PyTorch 1.7 中,torch.full() 将从填充值推断返回张量的 dtype

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行为保留在序列化 ScriptModules 中。也就是说,在 1.6 之前的 PyTorch 版本中序列化的 ScriptModules 即使在给定布尔或整数填充值时,默认情况下仍会看到 torch.full 返回浮点张量。然而,使用 torch.full() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在更早的 PyTorch 版本中加载,因为那些更早的版本不理解新行为。

实用函数#

以下实用函数与序列化相关

torch.serialization.register_package(priority, tagger, deserializer)[来源]#

注册可调用对象,用于标记和反序列化具有关联优先级的存储对象。标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到适当的设备。taggerdeserializer 按照其 priority 给出的顺序运行,直到 tagger/deserializer 返回一个非 None 的值。

要覆盖全局注册表中设备的反序列化行为,可以注册一个具有比现有 tagger 更高优先级的 tagger。

此函数还可用于为新设备注册 tagger 和 deserializer。

参数
返回

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[来源]#

获取 torch.save() 是否计算并为每个记录写入 crc32。

默认为 True

返回类型

布尔值

torch.serialization.set_crc32_options(compute_crc32)[来源]#

设置 torch.save() 是否计算并为每个记录写入 crc32。

注意

将其设置为 False 可能会导致 torch.save 输出的解压缩由于 CRC32 损坏而失败或发出警告。然而,torch.load 将能够加载文件。

参数

compute_crc32 (bool) – 设置 crc32 计算标志

torch.serialization.get_default_load_endianness()[来源]#

获取加载文件的默认字节顺序

如果保存的检查点中不存在字节顺序标记,则此字节顺序用作备用。默认情况下,它是“本机”字节顺序。

返回

Optional[LoadEndianness]

返回类型

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[来源]#

设置加载文件的默认字节顺序

如果保存的检查点中不存在字节顺序标记,则此字节顺序用作备用。默认情况下,它是“本机”字节顺序。

参数

endianness – 新的默认字节顺序

torch.serialization.get_default_mmap_options()[来源]#

获取 torch.load() 带有 mmap=True 时的默认 mmap 选项。

默认为 mmap.MAP_PRIVATE

返回

int

返回类型

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[来源]#

上下文管理器或函数,用于将 torch.load() 的默认 mmap 选项设置为 flags

目前,仅支持 mmap.MAP_PRIVATEmmap.MAP_SHARED。如果您需要添加其他选项,请提出 issue。

注意

此功能目前不支持 Windows。

参数

flags (int) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[来源]#

将给定的全局变量标记为 weights_only 加载安全。例如,添加到此列表的函数可以在反 pickle 期间调用,类可以实例化并设置状态。

列表中的每个项目可以是函数/类,也可以是 (函数/类, 字符串) 形式的元组,其中字符串是函数/类的完整路径。

在序列化格式中,每个函数都由其完整路径 {__module__}.{__qualname__} 标识。调用此 API 时,您可以提供此完整路径,该路径应与检查点中的路径匹配,否则将使用默认的 {fn.__module__}.{fn.__qualname__}

参数

safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[来源]#

清除 weights_only 加载安全的全局变量列表。

torch.serialization.get_safe_globals()[来源]#

返回用户添加的、weights_only 加载安全的全局变量列表。

返回类型

list[Union[Callable, tuple[Callable, str]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[来源]#

返回 torch.save 对象中不安全的函数/类字符串列表,用于 weights_only

对于给定的函数或类 f,相应的字符串将采用 {f.__module__}.{f.__name__} 形式。

此函数将返回检查点中所有未标记为 weights_only 安全的全局变量(无论是通过 add_safe_globals()safe_globals 上下文管理器,或默认情况下由 torch 列入白名单)。

注意

此函数将静态反汇编检查点中的 pickle 文件。这意味着在反 pickle 期间动态推送到堆栈的任何类都不会包含在输出中。

参数

f (Union[str, PathLike[str], IO[bytes]]) – 文件类对象或包含通过 torch.save 保存的检查点对象的字符串

返回

检查点中未列入 weights_only 白名单的 pickle 全局变量字符串列表。

返回类型

list[str]

class torch.serialization.safe_globals(safe_globals)[来源]#

上下文管理器,将某些全局变量添加为 weights_only 加载的安全变量。

参数

safe_globals (list[Union[Callable, tuple[Callable, str]]]) – 仅权重加载的全局变量列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[来源]#

上下文管理器,用于跳过 torch.save / torch.load 调用时写入/读取存储字节。

对于保存路径,存储仍将保存,但通常写入其字节的空间将为空白空间。然后可以在单独的传递中填充存储字节。

对于加载路径,张量将根据检查点加载,但其存储不会填充数据。

警告

skip_data 上下文管理器是一个早期原型,可能会发生变化。

参数

materialize_fake_tensors (bool) – 在保存期间是否具体化 FakeTensor。这对加载路径是空操作。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

配置#

torch.utils.serialization.config 提供了一个全局配置,可以控制 torch.savetorch.load 的行为。

torch.utils.serialization.config.save 包含控制 torch.save 行为的选项。

  • compute_crc32:是否计算并写入 zip 文件校验和(默认值:True)。请参阅 set_crc32_options()

  • use_pinned_memory_for_d2h:对于传递给 torch.save 时位于加速器上的存储,是否在 torch.save 中将存储移动到 CPU 上的固定内存或可分页内存。(默认值:False(即可分页))

  • storage_alignment:在 torch.save 期间检查点中存储的对齐方式,以字节为单位。(默认值 64

torch.utils.serialization.config.load 包含控制 torch.load 行为的选项。

  • mmap:请参阅 torch.load()mmap 参数的文档。如果未显式传递给 torch.load 调用,此配置将设置 torch.loadmmap 行为(默认值:False)。

  • endianness:请参阅 set_default_load_endianness()。(默认值:torch.serialization.LoadEndianness.NATIVE

  • mmap_flags:请参阅 set_default_mmap_options。(默认值:MAP_PRIVATE

  • calculate_storage_offsets:如果此配置设置为 True,则在使用 torch.load(mmap=True) 时,将计算存储的偏移量,而不是通过随机读取。这最大限度地减少了随机读取,这在通过网络加载文件时会很有帮助。(默认值:False