序列化语义#
创建于: 2017年2月26日 | 最后更新于: 2025年10月27日
本文档描述了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++ 中加载。
保存和加载张量#
torch.save() 和 torch.load() 可以方便地保存和加载张量。
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常以 ‘.pt’ 或 ‘.pth’ 扩展名保存。
torch.save() 和 torch.load() 默认使用 Python 的 pickle,因此您也可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分来保存。
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
包含 PyTorch 张量的自定义数据结构也可以保存,前提是该数据结构是可 pickle 的。
保存和加载张量会保留视图#
保存张量会保留它们的视图关系。
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在底层,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图。
当 PyTorch 保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个实现细节,未来可能会发生变化,但它通常可以节省空间,并允许 PyTorch 轻松地重建加载张量之间的视图关系。例如,在上方的代码片段中,只有单个存储被写入 'tensors.pt'。
然而,在某些情况下,保存当前的存储对象可能是不必要的,并且会创建过大的文件。在下面的代码片段中,一个比保存的张量大的多的存储被写入了一个文件。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
不是只将 small 张量中的五个值保存到 'small.pt',而是保存并加载了它与 large 共享的存储中的 999 个值。
当保存的元素数量少于其存储对象的张量时,可以通过先克隆张量来减小保存文件的大小。克隆张量会生成一个新的张量,其存储对象只包含张量中的值。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
然而,由于克隆的张量是相互独立的,因此它们没有原始张量拥有的任何视图关系。如果保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么在保存之前,必须小心构建新的张量,以最小化其存储对象的大小,同时仍然保持所需的视图关系。
保存和加载 torch.nn.Modules#
另请参阅: 教程:保存和加载模型
在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久化缓冲区。
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
为了兼容性,不建议直接保存模块,而是建议只保存其状态字典。Python 模块甚至有一个函数 load_state_dict(),用于从状态字典恢复其状态。
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先使用 torch.load() 从其文件中加载,然后使用 load_state_dict() 恢复状态。
即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用此模式。
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
`torch.save` 的序列化文件格式#
自 PyTorch 1.6.0 起,torch.save 默认返回一个未压缩的 ZIP64 归档,除非用户设置了 _use_new_zipfile_serialization=False。
在此归档中,文件按以下顺序排列:
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 条目如下:
data.pkl是传递给torch.save的对象进行 pickle 后的结果,不包含它所包含的torch.Storage对象。byteorder包含一个字符串,表示保存时sys.byteorder的值(“little” 或 “big”)。data/包含对象中的所有存储,其中每个存储都是一个单独的文件。version包含保存时的版本号,可以在加载时使用。
保存时,PyTorch 会确保每个文件的本地文件头都填充到 64 字节的倍数偏移量,从而确保每个文件的偏移量都是 64 字节对齐的。
注意
某些设备(如 XLA)上的张量被序列化为 pickle 的 numpy 数组。因此,它们的存储不会被序列化。在这种情况下,检查点中可能不存在 data/。
布局控制#
torch.load() 中的 mmap 参数允许惰性加载张量存储。
此外,还有一些高级功能允许更精细地控制和操作 torch.save 检查点。
torch.serialization.skip_data上下文管理器启用了:使用
torch.save保存检查点,其中包含用于稍后写入数据字节的空位。使用
torch.load加载检查点,并稍后填充张量的数据字节。
要在不为存储数据分配内存的情况下检查 torch.save 检查点中的张量元数据,请在 FakeTensorMode 上下文管理器中使用 torch.load。除了像 skip_data 那样跳过加载存储数据之外,它还会将存储标记上其在检查点内的偏移量,从而可以直接操作检查点。
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
m = nn.Linear(10, 10)
torch.save(m.state_dict(), "checkpoint.pt")
with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")
for k, v in fake_sd.items():
print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
# offset of the storage in the checkpoint
print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")
有关更多信息,此教程提供了使用这些功能操作检查点的全面示例。
`torch.load` 配合 `weights_only=True`#
从 2.6 版本开始,如果未传递 pickle_module 参数,torch.load 将使用 weights_only=True。
`weights_only` 安全性#
如 torch.load() 文档中所述,weights_only=True 将 torch.load 中使用的 unpickler 限制为仅执行 torch.Tensors 的 state_dicts 以及一些其他基本类型所需的函数/构建类。此外,与 pickle 模块提供的默认 Unpickler 不同,weights_only Unpickler 不允许在 unpickling 过程中动态导入任何内容。
weights_only=True 缩小了远程代码执行攻击的表面,但存在以下限制:
weights_only=True不能防御拒绝服务攻击。我们试图防止
torch.load(weights_only=True)过程中的内存损坏,但它们仍有可能发生。
请注意,即使 torch.load 本身没有发生内存损坏,加载也可能为下游代码创建意外的对象,这些对象也可能导致内存损坏(例如,用户代码中为稀疏张量创建的索引和值张量可能会读/写越界)。
`weights_only` 允许列表#
如上所述,在使用 torch.save 时,保存模块的 state_dict 是一种最佳实践。如果加载包含 nn.Module 的旧检查点,我们建议使用 weights_only=False。加载包含张量子类(tensor subclasses)的检查点时,很可能需要将某些函数/类添加到允许列表中,有关更多详细信息,请参见下文。
如果 weights_only Unpickler 遇到一个默认未列入允许列表的函数或类,您应该会看到一个可操作的错误,如下所示:
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
请按照错误消息中的步骤操作,并且只允许您信任的函数或类。
要获取检查点中所有尚未列入允许列表的全局变量(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它将返回一个字符串列表,格式为 {__module__}.{__name__}。如果您信任这些函数/类,您可以导入它们,并根据错误消息通过 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals 将它们列入允许列表。
要访问用户允许列表函数/类的列表,您可以使用 torch.serialization.get_safe_globals(),要清除当前列表,请参阅 torch.serialization.clear_safe_globals()。
排查 `weights_only` 问题#
获取不安全全局变量#
一个需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 是静态分析检查点,某些类型可能在 unpickling 过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。其中一个例子是 numpy 的 dtypes。在 numpy < 1.25 中,在允许列表 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的所有函数/类后,您可能会看到类似以下的错误:
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 来添加到允许列表。
在 numpy >=1.25 中,您会看到:
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) 来添加到允许列表。
环境变量#
有两个环境变量会影响 torch.load 的行为。如果您无法访问 torch.load 的调用点,这些变量会很有用。
TORCH_FORCE_WEIGHTS_ONLY_LOAD=1将覆盖所有torch.load调用点,强制使用weights_only=True。TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1会使torch.load调用点使用weights_only=False,**仅当**weights_only未作为参数传递时。
实用函数#
以下实用函数与序列化相关:
- torch.serialization.register_package(priority, tagger, deserializer)[source]#
注册用于为具有关联优先级的存储对象进行标记和反序列化的可调用对象。标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到适当的设备。
tagger和deserializer将按照其priority给出的顺序运行,直到 tagger/deserializer 返回一个非 None 的值。要覆盖全局注册表中某个设备的反序列化行为,可以注册一个优先级高于现有 tagger 的 tagger。
此函数还可用于为新设备注册 tagger 和 deserializer。
- 参数:
priority (int) – 指示与 tagger 和 deserializer 关联的优先级,值越低表示优先级越高。
tagger (Callable[[Storage | TypedStorage | UntypedStorage], str | None]) – 接受存储对象并返回其标记设备(字符串形式)或 None 的可调用对象。
deserializer (Callable[[Storage | TypedStorage | UntypedStorage, str], Storage | TypedStorage | UntypedStorage | None]) – 接受存储对象和设备字符串,并返回适当设备上的存储对象或 None 的可调用对象。
- 返回:
无
示例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_crc32_options()[source]#
获取
torch.save()是否为每个记录计算并写入 crc32。默认为
True。- 返回类型:
- torch.serialization.set_crc32_options(compute_crc32)[source]#
设置
torch.save()是否为每个记录计算并写入 crc32。注意
将此设置为
False可能会导致torch.save输出的解压失败或因 CRC32 损坏而发出警告。但是torch.load仍然可以加载文件。- 参数:
compute_crc32 (bool) – 设置 crc32 计算标志。
- torch.serialization.get_default_load_endianness()[source]#
获取加载文件的备用字节序。
如果保存的检查点中不存在字节序标记,则使用此字节序作为备用。默认情况下,它是“本机”字节序。
- 返回:
Optional[LoadEndianness]
- 返回类型:
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[source]#
设置加载文件的备用字节序。
如果保存的检查点中不存在字节序标记,则使用此字节序作为备用。默认情况下,它是“本机”字节序。
- 参数:
endianness – 新的备用字节序。
- torch.serialization.get_default_mmap_options()[source]#
获取
mmap=True时torch.load()的默认 mmap 选项。默认为
mmap.MAP_PRIVATE。- 返回:
int
- 返回类型:
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[source]#
用于设置
mmap=True时torch.load()的默认 mmap 选项到 flags 的上下文管理器或函数。目前仅支持
mmap.MAP_PRIVATE或mmap.MAP_SHARED。如果您需要添加其他选项,请提交一个 issue。注意
此功能目前不支持 Windows。
- 参数:
flags (int) –
mmap.MAP_PRIVATE或mmap.MAP_SHARED。
- torch.serialization.add_safe_globals(safe_globals)[source]#
将给定的全局变量标记为
weights_only加载是安全的。例如,添加到此列表的函数可以在 unpickling 过程中被调用,类可以被实例化并设置状态。列表中的每个项目都可以是函数/类,或者是一个元组,形式为 (函数/类, 字符串),其中字符串是函数/类的完整路径。
在序列化格式中,每个函数都用其完整路径标识为
{__module__}.{__qualname__}。调用此 API 时,您可以提供与检查点中匹配的完整路径,否则将使用默认的{fn.__module__}.{fn.__qualname__}。- 参数:
safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表。
示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization.get_unsafe_globals_in_checkpoint(f)[源码]#
返回一个字符串列表,列出
torch.save对象中不适合weights_only加载的函数/类。对于给定的函数或类
f,对应的字符串格式为{f.__module__}.{f.__name__}。此函数将返回检查点中所有未被标记为
weights_only安全的 GLOBALs(通过add_safe_globals()或safe_globals上下文,或默认被torch允许列表)。注意
此函数将静态反汇编检查点中的 pickle 文件。这意味着在反序列化过程中动态推入堆栈的任何类都不会包含在输出中。
- class torch.serialization.safe_globals(safe_globals)[源码]#
上下文管理器,用于将某些 globals 添加为
weights_only加载的安全项。示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[源码]#
上下文管理器,用于在
torch.save/torch.load调用中跳过存储字节的写入/读取。对于保存路径,存储仍然会被保存,但其字节原本要写入的空间将是空的。存储字节可以在后续的传递中填充。
对于加载路径,张量将根据检查点加载,但其存储不会被数据填充。
警告
skip_data上下文管理器是一个早期原型,可能会发生更改。- 参数:
materialize_fake_tensors (bool) – 在保存过程中是否实例化 FakeTensors。这在加载路径中是无操作。
示例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])
配置#
torch.utils.serialization.config 提供了可以控制 torch.save 和 torch.load 行为的全局配置。
torch.utils.serialization.config.save 包含控制 torch.save 行为的选项。
compute_crc32:是否计算和写入 zip 文件校验和(默认:True)。请参阅set_crc32_options()。
use_pinned_memory_for_d2h:对于在传递给torch.save时位于加速器上的存储,是否在torch.save中将存储移动到 CPU 的固定内存或可分页内存。(默认:False(即可分页))
storage_alignment:在torch.save期间,检查点中存储的字节对齐。(默认64)
torch.utils.serialization.config.load 包含控制 torch.load 行为的选项。
mmap:请参阅torch.load()中mmap参数的文档。此配置将设置torch.load的mmap行为,前提是它尚未被显式传递给torch.load调用。(默认:False)。
endianness:请参阅set_default_load_endianness()。(默认:torch.serialization.LoadEndianness.NATIVE)
mmap_flags:请参阅set_default_mmap_options。(默认:MAP_PRIVATE)
calculate_storage_offsets:如果此配置设置为True,则在使用torch.load(mmap=True)时,将计算存储的偏移量,而不是通过随机读取来获取。这最大限度地减少了随机读取,当文件通过网络加载时可能很有用。(默认:False)