序列化语义#
创建日期:2017年2月26日 | 最后更新日期:2025年6月23日
本文档描述了如何在Python中保存和加载PyTorch张量和模块状态,以及如何序列化Python模块以便在C++中加载。
目录
保存和加载张量#
torch.save() 和 torch.load() 可以方便地保存和加载张量。
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常使用 '.pt' 或 '.pth' 扩展名。
torch.save() 和 torch.load() 默认使用 Python 的 pickle,因此您也可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分进行保存。
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
包含 PyTorch 张量的自定义数据结构也可以保存,前提是该数据结构是可 pickle 的。
保存和加载张量会保留视图#
保存张量会保留它们的视图关系。
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在后台,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图。
当 PyTorch 保存张量时,它会单独保存它们的存储对象和张量元数据。这是可能在未来更改的实现细节,但它通常可以节省空间,并使 PyTorch 能够轻松地重建加载张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入 'tensors.pt'。
然而,在某些情况下,保存当前的存储对象可能是不必要的,并会导致生成过大的文件。在下面的代码片段中,为与 large 共享的存储(包含 999 个元素)创建的文件比 small 张量(仅包含 5 个元素)要大得多。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
与将 small 张量中的五个元素保存到 'small.pt' 不同,这里保存和加载的是 small 张量所共享的、包含 999 个元素的存储。
当保存的张量中的元素数量少于其存储对象中的元素数量时,可以通过先克隆张量来减小保存文件的大小。克隆张量会创建一个新的张量,它具有一个新的存储对象,仅包含张量中的值。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
但是,由于克隆的张量彼此独立,因此它们不具有原始张量所具有的任何视图关系。如果保存的张量小于其存储对象,并且文件大小和视图关系都很重要,那么在保存之前,必须小心地构造新的张量,以最小化其存储对象的大小,同时仍保持所需的视图关系。
保存和加载 torch.nn.Modules#
另请参阅:教程:保存和加载模型
在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久缓冲区。
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
为保证兼容性,建议不要直接保存模块,而是只保存其状态字典。Python 模块甚至有一个函数 load_state_dict(),用于从状态字典恢复其状态。
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先使用 torch.load() 从文件中加载,然后使用 load_state_dict() 恢复状态。
即使是自定义模块和包含其他模块的模块,也都有状态字典,并可以使用此模式。
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
序列化文件格式 for torch.save#
自 PyTorch 1.6.0 起,torch.save 默认返回一个未压缩的 ZIP64 存档,除非用户将 _use_new_zipfile_serialization 设置为 False。
在此存档中,文件按以下顺序排列:
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 条目如下:
data.pkl是对传递给torch.save的对象进行 pickle 的结果,但不包含其中包含的torch.Storage对象。byteorder包含一个字符串,其中是保存时sys.byteorder的值(“little” 或 “big”)。data/包含对象中的所有存储,每个存储是一个单独的文件。version包含保存时的版本号,可以在加载时使用。
保存时,PyTorch 将确保每个文件的本地文件头都会填充到 64 字节的倍数偏移量,从而确保每个文件的偏移量都是 64 字节对齐的。
注意
某些设备(如 XLA)上的张量被序列化为 pickled numpy 数组。因此,它们的存储不会被序列化。在这种情况下,检查点中可能不存在 data/ 目录。
布局控制#
在 torch.load() 中的 mmap 参数允许对张量存储进行延迟加载。
此外,还有一些高级功能允许对 torch.save 检查点进行更细粒度的控制和操作。
- 使用
torch.serialization.skip_data上下文管理器可以: 使用
torch.save保存一个包含数据字节预留空间的检查点,以便之后写入。使用
torch.load加载一个检查点,并稍后填充张量的数据字节。
要检查 torch.save 检查点中的张量元数据而不分配存储数据内存,请在 FakeTensorMode 上下文管理器中使用 torch.load。除了跳过加载存储数据(类似于上面的 skip_data)之外,它还会将存储标记上其在检查点内的偏移量,从而可以直接操作检查点。
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
m = nn.Linear(10, 10)
torch.save(m.state_dict(), "checkpoint.pt")
with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")
for k, v in fake_sd.items():
print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
# offset of the storage in the checkpoint
print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")
有关更多信息,此教程提供了使用这些功能操作检查点的综合示例。
torch.load with weights_only=True#
从 2.6 版本开始,如果未传递 pickle_module 参数,torch.load 将使用 weights_only=True。
如 torch.load() 的文档中所述,weights_only=True 将 torch.load 中使用的反 pickle 模块限制为仅执行 torch.Tensors 的 state_dicts 以及其他一些基本类型所需的函数/类。此外,与 pickle 模块提供的默认 Unpickler 不同,weights_only Unpickler 不允许在反 pickling 过程中动态导入任何内容。
如上所述,使用 torch.save 保存模块的 state_dict 是最佳实践。如果加载包含 nn.Module 的旧检查点,我们建议使用 weights_only=False。加载包含张量子类(tensor subclasses)的检查点时,很可能会有需要添加到白名单的函数/类,有关详细信息,请参阅下文。
如果 weights_only Unpickler 遇到一个默认未被白名单的函数或类,您应该会看到类似以下的错误消息:
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
请按照错误消息中的步骤操作,并仅在您信任这些函数或类时才将其添加到白名单。
要获取检查点中尚未被白名单的所有全局变量(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它将返回一个字符串列表,格式为 {__module__}.{__name__}。如果您信任这些函数/类,可以导入它们并通过 torch.serialization.add_safe_globals() 或使用 torch.serialization.safe_globals 上下文管理器将它们添加到白名单。
要访问用户白名单的函数/类列表,可以使用 torch.serialization.get_safe_globals(),要清除当前列表,请参阅 torch.serialization.clear_safe_globals()。
解决 weights_only 问题#
获取不安全的全局变量#
需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 会对检查点进行静态分析,某些类型可能在反 pickling 过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。其中一个例子是 numpy 中的 dtypes。在 numpy < 1.25 中,在将 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的所有函数/类添加到白名单后,您可能会看到类似以下的错误:
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 添加到白名单。
在 numpy >=1.25 中,您会看到:
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) 添加到白名单。
实用函数#
以下实用函数与序列化相关:
- torch.serialization.register_package(priority, tagger, deserializer)[source]#
注册用于为具有关联优先级的存储对象进行标记和反序列化的可调用对象。标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到合适的设备。
tagger和deserializer将按其priority指定的顺序运行,直到 tagger/deserializer 返回一个非 None 的值。要覆盖全局注册表中某个设备的反序列化行为,可以注册一个优先级高于现有 tagger 的 tagger。
此函数也可用于为新设备注册 tagger 和 deserializer。
- 参数
priority (int) – 指示与 tagger 和 deserializer 关联的优先级,值越小表示优先级越高。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 接受存储对象并返回其标记的设备(字符串)或 None 的可调用对象。
deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 接受存储对象和设备字符串,并返回合适设备上的存储对象或 None 的可调用对象。
- 返回
无
示例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_crc32_options()[source]#
获取
torch.save()是否为每个记录计算并写入 crc32。默认为
True。- 返回类型
- torch.serialization.set_crc32_options(compute_crc32)[source]#
设置
torch.save()是否为每个记录计算并写入 crc32。注意
将其设置为
False可能会导致torch.save输出的解压失败或因 CRC32 损坏而发出警告。但是torch.load仍然可以加载文件。- 参数
compute_crc32 (bool) – 设置 crc32 计算标志
- torch.serialization.get_default_load_endianness()[source]#
获取加载文件的后备字节序。
如果已保存的检查点中不存在字节序标记,则使用此字节序作为后备。默认情况下,它是“本地”(native)字节序。
- 返回
Optional[LoadEndianness]
- 返回类型
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[source]#
设置加载文件的后备字节序。
如果已保存的检查点中不存在字节序标记,则使用此字节序作为后备。默认情况下,它是“本地”(native)字节序。
- 参数
endianness – 新的后备字节序
- torch.serialization.get_default_mmap_options()[source]#
获取
torch.load()和mmap=True的默认 mmap 选项。默认为
mmap.MAP_PRIVATE。- 返回
int
- 返回类型
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[source]#
上下文管理器或函数,用于为
torch.load()和mmap=True设置默认 mmap 选项为 flags。目前,只支持
mmap.MAP_PRIVATE或mmap.MAP_SHARED。如果您需要添加其他选项,请提交一个 issue。注意
此功能目前不支持 Windows。
- 参数
flags (int) –
mmap.MAP_PRIVATE或mmap.MAP_SHARED
- torch.serialization.add_safe_globals(safe_globals)[source]#
将给定的全局变量标记为
weights_only加载是安全的。例如,添加到此列表中的函数可以在反 pickling 过程中被调用,类可以被实例化并设置状态。列表中的每个项可以是函数/类,或者是一个元组,形式为 (函数/类, 字符串),其中字符串是函数/类的完整路径。
在序列化格式中,每个函数都用其完整路径
{__module__}.{__qualname__}来标识。调用此 API 时,您可以提供应与检查点中的路径匹配的完整路径,否则将使用默认的{fn.__module__}.{fn.__qualname__}。- 参数
safe_globals (List[Union[Callable, Tuple[Callable, str]]) – 要标记为安全的全局变量列表
示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization.get_unsafe_globals_in_checkpoint(f)[source]#
返回
torch.save对象中不安全(不适用于weights_only加载)的函数/类的字符串列表。对于给定的函数或类
f,相应的字符串格式为{f.__module__}.{f.__name__}。此函数将返回检查点中所有不属于
weights_only安全集(通过add_safe_globals()、safe_globals上下文或torch默认白名单)的全局变量。注意
此函数将静态地反汇编检查点中的 pickle 文件。这意味着任何在反 pickling 过程中动态推送到堆栈的类都不会包含在输出中。
- class torch.serialization.safe_globals(safe_globals)[source]#
上下文管理器,用于将某些全局变量添加为
weights_only加载的安全项。示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[source]#
上下文管理器,用于跳过
torch.save/torch.load调用中的存储字节的写入/读取。对于保存路径,存储仍会被保存,但其字节通常会写入的空间将为空。然后可以在单独的传递中填充存储字节。
对于加载路径,张量将根据检查点加载,但其存储不会填充数据。
警告
skip_data上下文管理器是一个早期原型,可能会发生更改。- 参数
materialize_fake_tensors (bool) – 保存时是否具体化 FakeTensors。这对加载路径是无操作。
示例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])
配置#
torch.utils.serialization.config 提供了一个全局配置,可以控制 torch.save 和 torch.load 的行为。
torch.utils.serialization.config.save 包含控制 torch.save 行为的选项。
compute_crc32: 是否计算并写入 zip 文件校验和 (默认:True)。请参阅set_crc32_options()。
use_pinned_memory_for_d2h: 对于传递到torch.save时位于加速器上的存储,是否在torch.save中将存储移动到 CPU 的固定内存或可分页内存。(默认:False(即可分页))
storage_alignment: 在torch.save期间,检查点中存储的对齐字节数。(默认64)
torch.utils.serialization.config.load 包含控制 torch.load 行为的选项。
mmap: 请参阅torch.load()中mmap参数的文档。此配置将设置torch.load的mmap行为,如果它没有被显式传递给torch.load调用的话 (默认:False)。
endianness: 请参阅set_default_load_endianness()。(默认:torch.serialization.LoadEndianness.NATIVE)
mmap_flags: 请参阅set_default_mmap_options。(默认:MAP_PRIVATE)
calculate_storage_offsets: 如果此配置设置为True,则在使用torch.load(mmap=True)时,将计算存储的偏移量,而不是通过随机读取来获取。这可以最大限度地减少随机读取,当文件通过网络加载时可能会很有用。(默认:False)