torch.jit.trace_module#
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[source]#
跟踪一个模块并返回一个通过即时编译优化的可执行
ScriptModule
。当一个模块被传递给
torch.jit.trace
时,只有forward
方法会被运行和跟踪。使用trace_module
,您可以指定一个方法名到示例输入的字典来跟踪(请参阅下面的inputs
参数)。有关跟踪的更多信息,请参阅
torch.jit.trace
。- 参数
mod (torch.nn.Module) – 一个
torch.nn.Module
,其中包含在inputs
中指定的方法名。给定的方法将被编译为单个 ScriptModule 的一部分。inputs (dict) – 一个包含由
mod
中方法名索引的示例输入的字典。在跟踪时,输入将被传递给名称与输入的键对应的那些方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
- 关键字参数
check_trace (
bool
, optional) – 检查相同的输入通过跟踪代码是否产生相同的输出。默认值:True
。如果您希望禁用此选项,例如,因为您的网络包含非确定性操作,或者如果您确信网络是正确的,尽管检查器失败。check_inputs (list of dicts, optional) – 一个输入参数字典列表,用于根据预期检查跟踪。每个元组等同于在
inputs
中指定的一组输入参数。为了获得最佳结果,请提供一组能够代表网络预期输入的形状和类型的检查输入。如果未指定,则使用原始inputs
进行检查。check_tolerance (float, optional) – 在检查器过程中使用的浮点数比较容差。当由于已知原因(例如操作符融合)导致结果在数值上出现偏差时,可以使用此参数来放宽检查器的严格性。
example_inputs_is_kwarg (
bool
, optional) – 此参数指示示例输入是否为关键字参数的集合。默认值:False
。
- 返回
一个
ScriptModule
对象,其中包含一个forward
方法,该方法包含跟踪的代码。当func
是一个torch.nn.Module
时,返回的ScriptModule
将具有与func
相同的子模块和参数集。
示例(跟踪具有多个方法的模块)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = { "forward": example_forward_input, "weighted_kernel_sum": example_weight, } module = torch.jit.trace_module(n, inputs)