评价此页

torch.jit.trace_module#

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[source]#

将一个模块进行追踪,并返回一个经过即时编译优化的可执行的 ScriptModule

当一个模块被传递给 torch.jit.trace 时,只有 forward 方法会被运行和追踪。使用 trace_module,你可以指定一个方法名到示例输入的字典来追踪(参阅下面的 inputs 参数)。

有关追踪的更多信息,请参阅 torch.jit.trace

参数
  • mod (torch.nn.Module) – 一个 torch.nn.Module,其中包含的方法名在 inputs 中指定。给定的方法将被编译为一个单独的 ScriptModule 的一部分。

  • inputs (dict) – 一个字典,包含按方法名索引的样本输入,这些方法名在 mod 中。在追踪时,输入将被传递给方法,方法名与输入键相对应。{ 'forward' : example_forward_input, 'method2': example_method2_input}

关键字参数
  • check_trace (bool, optional) – 检查相同的输入运行追踪后的代码是否产生相同的输出。默认值:True。你可能希望禁用此项,例如,如果你的网络包含非确定性操作,或者如果你确定网络在检查器失败的情况下仍然是正确的。

  • check_inputs (list of dicts, optional) – 一系列字典,其中包含用于将追踪结果与预期进行比较的输入参数。每个元组等同于一组将被指定在 inputs 中的输入参数。为获得最佳结果,请传递一组具有代表性的检查输入,以覆盖网络将看到的输入形状和类型。如果未指定,则使用原始 inputs 进行检查。

  • check_tolerance (float, optional) – 在检查器过程中使用的浮点数比较容差。当由于已知原因(如算子融合)导致结果在数值上出现分歧时,可以使用此参数来放宽检查器的严格性。

  • example_inputs_is_kwarg (bool, optional) – 此参数指示示例输入是否为关键字参数的集合。默认值:False

返回

一个 ScriptModule 对象,其中包含一个跟踪代码的 forward 方法。当 func 是一个 torch.nn.Module 时,返回的 ScriptModule 将具有与 func 相同的子模块和参数集。

示例(跟踪具有多个方法的模块)

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {
    "forward": example_forward_input,
    "weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)