评价此页

torch.jit.save#

torch.jit.save(m, f, _extra_files=None)[source]#

将此模块的离线版本保存起来,以便在单独的进程中使用。

保存的模块会序列化该模块的所有方法、子模块、参数和属性。它可以使用 torch::jit::load(filename) 加载到 C++ API 中,或者使用 torch.jit.load 加载到 Python API 中。

为了能够保存模块,它不能调用任何原生 Python 函数。这意味着所有子模块都必须是 ScriptModule 的子类。

Danger

所有模块,无论其设备是什么,在加载时都会被加载到 CPU 上。这与 torch.load() 的语义不同,并且未来可能会发生变化。

参数
  • m – 要保存的 ScriptModule

  • f – 一个类文件对象(必须实现 write 和 flush 方法)或一个包含文件名的字符串。

  • _extra_files – 一个文件名到内容的映射,这些内容将作为 f 的一部分进行存储。

注意

torch.jit.save 会尝试在不同版本之间保留某些运算符的行为。例如,在 PyTorch 1.5 中,两个整数张量相除会执行整除;如果在 PyTorch 1.5 中保存了包含该代码的模块,并在 PyTorch 1.6 中加载,其除法行为将得以保留。然而,在 PyTorch 1.6 中保存的相同模块在 PyTorch 1.5 中将无法加载,因为除法行为在 1.6 中发生了变化,而 1.5 无法复制 1.6 的行为。

示例: .. testcode

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)