评价此页

torch.load#

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[source]#

从文件中加载使用 torch.save() 保存的对象。

torch.load() 使用 Python 的反序列化功能,但会特殊处理张量底层的存储。它们首先在 CPU 上反序列化,然后移动到保存时的设备。如果失败(例如,因为运行时系统没有某些设备),则会引发异常。但是,可以使用 map_location 参数动态地将存储重新映射到备用设备集。

如果 map_location 是一个可调用对象,它将为每个序列化的存储调用一次,有两个参数:存储和位置。存储参数将是存储的初始反序列化,驻留在 CPU 上。每个序列化的存储都有一个关联的位置标签,该标签标识了它保存的设备,并且该标签是传递给 map_location 的第二个参数。内置的位置标签是 CPU 张量的 'cpu',CUDA 张量的 'cuda:device_id'(例如 'cuda:2')。map_location 应返回 None 或一个存储。如果 map_location 返回一个存储,它将被用作最终反序列化的对象,并已移动到正确的设备。否则,torch.load() 将回退到默认行为,就好像没有指定 map_location 一样。

如果 map_location 是一个 torch.device 对象或包含设备标签的字符串,它将指示所有张量应加载到的位置。

否则,如果 map_location 是一个字典,它将用于将文件中出现的位置标签(键)重新映射到指定存储位置(值)的标签。

用户扩展可以使用 torch.serialization.register_package() 注册自己的位置标签以及标记和反序列化方法。

有关操作 checkpoint 的更高级工具,请参阅 布局控制

参数
  • f (Union[str, PathLike[str], IO[bytes]]) – 一个类似文件的对象(必须实现 read()readline()tell()seek()),或包含文件名的字符串或 os.PathLike 对象

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – 一个函数、torch.device、字符串或字典,用于指定如何重新映射存储位置

  • pickle_module (Optional[Any]) – 用于反序列化元数据和对象的模块(必须与序列化文件时使用的 pickle_module 匹配)

  • weights_only (Optional[bool]) – 指示反序列化器是否应仅限于加载张量、基本类型、字典以及通过 torch.serialization.add_safe_globals() 添加的任何类型。有关更多详细信息,请参阅 torch.load with weights_only=True

  • mmap (Optional[bool]) – 指示是否应映射文件而不是将所有存储加载到内存中。通常,文件中的张量存储将首先从磁盘移动到 CPU 内存,然后移动到保存时标记的设备,或者 map_location 指定的设备。如果最终位置是 CPU,则此第二步为空操作。当设置 mmap 标志时,而不是在第一步将张量存储从磁盘复制到 CPU 内存,f 将被映射,这意味着张量存储将在访问其数据时惰性加载。

  • pickle_load_args (Any) – (仅限 Python 3) 传递给 pickle_module.load()pickle_module.Unpickler() 的可选关键字参数,例如 errors=...

返回类型

任何

警告

torch.load(),除非 weights_only 参数设置为 True,否则会隐式使用 pickle 模块,该模块已知不安全。有可能构造恶意的 pickle 数据,这些数据将在反序列化过程中执行任意代码。切勿在不安全模式下加载可能来自不受信任来源或可能被篡改的数据。**仅加载您信任的数据**。

注意

当你对包含 GPU 张量的文件调用 torch.load() 时,默认情况下,这些张量将被加载到 GPU。你可以调用 torch.load(.., map_location='cpu') 然后调用 load_state_dict() 来避免在加载模型 checkpoint 时出现 GPU 内存激增。

注意

默认情况下,我们将字节字符串解码为 utf-8。这是为了避免在 Python 3 中加载 Python 2 保存的文件时出现常见的错误情况 UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果此默认值不正确,你可以使用额外的 encoding 关键字参数来指定如何加载这些对象,例如 encoding='latin1' 使用 latin1 编码将它们解码为字符串,而 encoding='bytes' 将它们保留为字节数组,稍后可以用 byte_array.decode(...) 解码。

示例

>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load(
...     "tensors.pt",
...     map_location=torch.device("cpu"),
...     weights_only=True,
... )
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage,
...     weights_only=True,
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load(
...     "tensors.pt",
...     map_location={"cuda:1": "cuda:0"},
...     weights_only=True,
... )
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)