torch.jit.script#
- torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[source]#
脚本化函数。
脚本化函数或 `nn.Module` 会检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回一个 `ScriptModule` 或 `ScriptFunction`。TorchScript 本身是 Python 语言的一个子集,因此并非所有 Python 功能都可用,但我们提供了足够的功能来在张量上进行计算和执行依赖于控制的操作。有关完整指南,请参阅 TorchScript 语言参考。
脚本化字典或列表会将其内部数据复制到 TorchScript 实例中,然后该实例可以在 Python 和 TorchScript 之间以零拷贝的开销进行引用传递。
- `torch.jit.script` 可以用作模块、函数、字典和列表的函数,
也可以用作 `TorchScript 类` 和函数的装饰器 `@torch.jit.script`。
- 参数
**obj** (Callable, class, or nn.Module) – 要编译的 `nn.Module`、函数、类类型、字典或列表。
**example_inputs** (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例输入以注解函数或 `nn.Module` 的参数。
- 返回
如果 `obj` 是 `nn.Module`,`script` 返回一个 `ScriptModule` 对象。返回的 `ScriptModule` 将具有与原始 `nn.Module` 相同的子模块和参数集。如果 `obj` 是独立函数,则返回 `ScriptFunction`。如果 `obj` 是 `dict`,则 `script` 返回 `torch._C.ScriptDict` 的实例。如果 `obj` 是 `list`,则 `script` 返回 `torch._C.ScriptList` 的实例。
- 脚本化函数
装饰器 `@torch.jit.script` 将通过编译函数体来构建 `ScriptFunction`。
示例(脚本化函数)
import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2))
- **使用 example_inputs 脚本化函数
示例输入可用于注解函数参数。
示例(脚本化前的函数注解)
import torch def test_sum(a, b): return a + b # Annotate the arguments to be int scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) print(type(scripted_fn)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(scripted_fn.code) # Call the function using the TorchScript interpreter scripted_fn(20, 100)
- 脚本化 nn.Module
默认情况下,通过脚本化 `nn.Module` 将编译 `forward` 方法,并递归编译 `forward` 调用或子模块、子模块和函数。如果 `nn.Module` 只使用了 TorchScript 支持的特性,则无需修改原始模块代码。`script` 将构建一个 `ScriptModule`,它具有原始模块的属性、参数和方法的副本。
示例(脚本化带 Parameter 的简单模块)
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super().__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3))
示例(脚本化带被跟踪的子模块的模块)
import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self) -> None: super().__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule())
要编译 `forward` 以外的方法(并递归编译它调用的任何内容),请将 `@torch.jit.export` 装饰器添加到该方法。要选择退出编译,请使用 `@torch.jit.ignore` 或 `@torch.jit.unused`。
示例(模块中的导出和忽略方法)
import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self) -> None: super().__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2)))
示例(使用 example_inputs 注解 nn.Module 的 forward)
import torch import torch.nn as nn from typing import NamedTuple class MyModule(NamedTuple): result: List[int] class TestNNModule(torch.nn.Module): def forward(self, a) -> MyModule: result = MyModule(result=a) return result pdt_model = TestNNModule() # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) # Run the scripted_model with actual inputs print(scripted_model([20]))