评价此页

torch.jit.trace#

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[source]#

跟踪函数并返回一个可执行的或 ScriptFunction,它将使用即时编译进行优化。

跟踪最适合仅处理 Tensor 和包含 Tensor 的列表、字典和元组的代码。

使用 torch.jit.tracetorch.jit.trace_module,您可以将现有的模块或 Python 函数转换为 TorchScript ScriptFunctionScriptModule。您必须提供示例输入,我们会运行该函数,记录所有张量上执行的操作。

  • 对独立函数进行的由此产生的记录会生成 ScriptFunction

  • nn.Module.forwardnn.Module 进行由此产生的记录会生成 ScriptModule

此模块还包含原始模块的任何参数。

警告

跟踪只能正确记录不依赖于数据的函数和模块(例如,不包含张量中数据的条件判断),并且不包含任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。跟踪仅记录给定函数在给定张量上运行时执行的操作。因此,返回的 ScriptModule 将始终在任何输入上运行相同的跟踪图。当您的模块预期运行不同操作集时,这有一些重要的含义,具体取决于输入和/或模块状态。例如,

  • 跟踪不会记录任何控制流,例如 if 语句或循环。当该控制流在您的模块中是常量时,这是没问题的,并且它通常会内联控制流决策。但有时控制流实际上是模型本身的一部分。例如,循环神经网络是围绕输入序列(可能动态)长度的循环。

  • 在返回的 ScriptModule 中,在 trainingeval 模式下行为不同的操作将始终表现为在跟踪期间处于的模式,无论 ScriptModule 处于哪种模式。

在这些情况下,跟踪将不适用,而 脚本化 是更好的选择。如果您跟踪这些模型,您可能会在后续调用模型时默默地获得不正确的结果。跟踪器会尝试在执行可能导致生成错误跟踪的操作时发出警告。

参数

func (callable or torch.nn.Module) – 将使用 example_inputs 运行的 Python 函数或 torch.nn.Modulefunc 的参数和返回值必须是张量,或包含张量的(可能嵌套的)元组。当将模块传递给 torch.jit.trace 时,仅运行和跟踪 forward 方法(有关详细信息,请参阅 torch.jit.trace)。

关键字参数
  • example_inputs (tuple or torch.Tensor or None, optional) – 在跟踪时将传递给函数的示例输入元组。默认为 None。应指定此参数或 example_kwarg_inputs。生成的跟踪可以与不同类型和形状的输入一起运行,前提是跟踪的操作支持这些类型和形状。 example_inputs 也可以是单个张量,在这种情况下它会自动包装在元组中。当值为 None 时,应指定 example_kwarg_inputs

  • check_trace (bool, optional) – 检查通过跟踪代码运行的相同输入是否产生相同的输出。默认为 True。如果您需要禁用此选项,例如,如果您的网络包含非确定性操作,或者如果您确信网络是正确的(尽管检查器失败)。

  • check_inputs (list of tuples, optional) – 用于将跟踪与预期进行比较的一组输入参数的元组列表。每个元组相当于在 example_inputs 中指定的输入参数集。为了获得最佳结果,请传入一组代表网络预期输入的形状和类型空间的检查输入。如果未指定,则使用原始 example_inputs 进行检查。

  • check_tolerance (float, optional) – 检查器过程中使用的浮点数比较容差。在已知原因(例如,算子融合)导致结果在数值上发生分歧的情况下,可以使用此选项来放宽检查器的严格性。

  • strict (bool, optional) – 以严格模式运行跟踪器或不运行(默认为 True)。仅当您希望跟踪器记录您的可变容器类型(当前是 list/dict)并且您确信您在问题中使用的容器是 constant 结构并且不被用作控制流(if, for)条件时,才将其关闭。

  • example_kwarg_inputs (dict, optional) – 此参数是在跟踪时传递给函数的示例输入的关键字参数包。默认为 None。应指定此参数或 example_inputs。字典将通过跟踪函数的参数名称进行解包。如果字典的键与跟踪函数的参数名称不匹配,将引发运行时异常。

返回

如果 funcnn.Modulenn.Moduleforward,则 trace 返回一个具有单个 forward 方法的 ScriptModule 对象,该方法包含跟踪的代码。返回的 ScriptModule 将具有与原始 nn.Module 相同的子模块和参数集。如果 func 是独立函数,则 trace 返回 ScriptFunction

示例(跟踪函数)

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

示例(跟踪现有模块)

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)