评价此页

torch.fx#

创建日期:2020年12月15日 | 最后更新日期:2025年12月19日

概述#

FX 是一个供开发人员用于转换 nn.Module 实例的工具包。FX 由三个主要组件组成:符号追踪器 (symbolic tracer)中间表示 (intermediate representation)Python 代码生成 (Python code generation)。以下是这些组件在实际运行中的演示。

import torch


# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪器对 Python 代码执行“符号执行”。它在代码中传递虚假值(称为 Proxy)。这些 Proxy 上的操作会被记录下来。有关符号追踪的更多信息,请参见 symbolic_trace()Tracer 文档。

中间表示 (IR) 是符号追踪期间记录的操作的容器。它由一系列节点 (Nodes) 组成,这些节点代表函数输入、调用点(针对函数、方法或 torch.nn.Module 实例)和返回值。有关 IR 的更多信息可以在 Graph 的文档中找到。IR 是应用转换的格式。

Python 代码生成使 FX 成为一个“Python 到 Python”(或“Module 到 Module”)的转换工具包。对于每个 Graph IR,我们可以创建符合 Graph 语义的有效 Python 代码。此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,持有一个 Graph 以及从该 Graph 生成的 forward 方法。

综合来看,这个组件管道(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号追踪可以独立使用,以捕获用于分析(而非转换)目的的代码形式。代码生成可用于程序化生成模型,例如从配置文件生成。FX 有许多用途!

可以在 examples 仓库中找到几个转换示例。

编写转换#

什么是 FX 转换?本质上,它是一个长这样的函数。


import torch
import torch.fx

def transform(m: nn.Module,
                tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

你的转换将接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,并返回一个新的 torch.nn.Module。你应该将 FX 转换返回的 torch.nn.Module 视为与常规 torch.nn.Module 相同——你可以将其传递给另一个 FX 转换,也可以运行它。确保 FX 转换的输入和输出都是 torch.nn.Module 将有助于实现可组合性。

注意

也可以修改现有的 GraphModule 而不是创建一个新的,如下所示

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

注意,你必须调用 GraphModule.recompile() 才能使 GraphModule 上生成的 forward() 方法与修改后的 Graph 保持同步。

假设你已经传入了一个被追踪为 Graphtorch.nn.Module,现在有两种主要方法可以构建新的 Graph

图 (Graphs) 快速入门#

有关图语义的完整说明可以在 Graph 文档中找到,但我们在这里将介绍一些基础知识。 Graph 是代表 GraphModule 上某个方法的数据结构。它所需的信息包括

  • 方法的输入是什么?

  • 方法内部运行的操作是什么?

  • 方法的输出(即返回)值是什么?

这三个概念都由 Node 实例表示。让我们通过一个简短的例子来看看这意味着什么


import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

这里我们定义一个用于演示的模块 MyModule,实例化它,进行符号追踪,然后调用 Graph.print_tabular() 方法打印出一个显示该 Graph 节点的表格

opcode (操作码)

name (名称)

target (目标)

args (参数)

kwargs (关键字参数)

placeholder (占位符)

x

x

()

{}

get_attr (获取属性)

linear_weight

linear.weight

()

{}

call_function (调用函数)

add_1

(x, linear_weight)

{}

call_module (调用模块)

linear_1

linear

(add_1,)

{}

call_method (调用方法)

relu_1

relu

(linear_1,)

{}

call_function (调用函数)

sum_1

<内置方法 sum …>

(relu_1,)

{‘dim’: -1}

call_function (调用函数)

topk_1

<内置方法 topk …>

(sum_1, 3)

{}

output

output

output

(topk_1,)

{}

我们可以利用这些信息来回答我们上面提出的问题。

  • 方法的输入是什么?在 FX 中,方法输入通过特殊的 placeholder 节点指定。在这种情况下,我们有一个 targetx 的单一 placeholder 节点,意味着我们有一个名为 x 的单一(非 self)参数。

  • 方法内部的操作是什么?get_attrcall_functioncall_modulecall_method 节点代表方法中的操作。有关所有这些节点语义的完整说明可以在 Node 文档中找到。

  • 方法的返回值是什么?Graph 中的返回值由特殊的 output 节点指定。

既然我们已经了解了 FX 中代码表示的基础知识,我们现在可以探索如何编辑 Graph

图操作 (Graph Manipulation)#

直接图操作#

构建新 Graph 的一种方法是直接操作旧图。为了实现这一点,我们可以简单地提取符号追踪获得的 Graph 并对其进行修改。例如,假设我们希望将 torch.add() 调用替换为 torch.mul() 调用。


import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
                tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                    # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以执行更复杂的 Graph 重写,例如删除或追加节点。为了辅助这些转换,FX 在 Graph 文档中提供了转换图的实用函数。下面可以找到一个使用这些 API 追加 torch.relu() 调用的示例。


# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅由替换组成的简单转换,你还可以使用 子图重写器 (subgraph rewriter)

使用 replace_pattern() 进行子图重写#

FX 在直接图操作的基础上还提供了另一层自动化。 replace_pattern() API 本质上是一个用于编辑 Graph 的“查找/替换”工具。它允许你指定一个 pattern (模式) 函数和 replacement (替换) 函数,它会追踪这些函数,在 pattern 图中查找操作组的实例,并将这些实例替换为 replacement 图的副本。这可以极大地自动化繁琐的图操作代码,随着转换变得越来越复杂,这些代码可能会变得难以管理。

图操作示例#

Proxy/重新追踪 (Proxy/Retracing)#

操作 Graph 的另一种方式是重复使用符号追踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是执行必要的图重写,在 F.relu 之后插入比较和乘法,然后清理原始的 F.relu。但是,我们可以通过使用 Proxy 对象自动将操作记录到 Graph 中来自动化此过程。

要使用此方法,我们将要插入的操作编写为常规 PyTorch 代码,并使用 Proxy 对象作为参数调用该代码。这些 Proxy 对象将捕获对它们执行的操作并将其追加到 Graph 中。


# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
                tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式的图操作外,使用 Proxy 还允许你将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy 时,我们还传递了一个指向底层变量 graph 的追踪器。这样做是为了防止图中的操作是多元的(例如 add 是一个二元运算符)时,对 Proxy 的调用创建多个图追踪器实例,从而导致意外的运行时错误。我们建议使用这种 Proxy 使用方法,特别是当底层操作符不能被安全地假设为一元操作符时。

有关使用 Proxy 进行 Graph 操作的一个实例,可以参考 这里

解释器模式 (The Interpreter Pattern)#

FX 中一个非常有用的代码组织模式是遍历 Graph 中的所有 Node 并执行它们。这可以用于多种用途,包括流经图的值的运行时分析,或通过使用 Proxy 进行重新追踪来转换代码。例如,假设我们想运行一个 GraphModule,并记录我们在运行时看到的节点上的 torch.Tensor 形状和数据类型属性。代码可能如下所示


import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':

                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如你所见,一个完整的 FX 解释器并不复杂,但它非常有用。为了方便使用这种模式,我们提供了 Interpreter 类,它封装了上述逻辑,使得解释器执行的某些方面可以通过方法重写来改变。

除了执行操作外,我们还可以通过在解释器中传递 Proxy 值来生成新的 Graph。同样,我们提供了 Transformer 类来涵盖这种模式。 Transformer 的行为与 Interpreter 类似,但你不是调用 run 方法从模块中获取具体的输出值,而是调用 Transformer.transform() 方法来返回一个新的 GraphModule,该模块受到你作为重写方法安装的任何转换规则的影响。

解释器模式示例#

调试#

简介#

在编写转换的过程中,我们的代码经常会出现一些偏差。在这种情况下,我们可能需要进行调试。关键是要倒推:首先,检查调用生成的模块的结果以证明或反驳正确性。然后,检查并调试生成的代码。最后,调试导致生成代码的转换过程。

如果你不熟悉调试器,请参阅辅助章节 可用调试器

转换编写中的常见陷阱#

  • 非确定性的 set 迭代顺序。在 Python 中,set 数据类型是无序的。例如,使用 set 来包含诸如 Node 之类的对象集合,可能会导致意外的非确定性。一个例子是迭代一个 Node 集合以将它们插入到 Graph 中。由于 set 数据类型是无序的,输出程序中操作的顺序将是非确定性的,并且在不同程序调用之间可能会发生变化。建议的替代方案是使用 dict 数据类型,自 Python 3.7 起(以及 cPython 3.6 起),它是按插入顺序排列的。通过将要去重的值存储在 dict 的键中,dict 的使用效果等同于 set。

检查模块的正确性#

因为大多数深度学习模块的输出由浮点数 torch.Tensor 实例组成,所以检查两个 torch.nn.Module 的结果是否等价并不像进行简单的相等检查那么直接。为了说明这一点,让我们看一个例子


import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用 == 相等运算符检查两个深度学习模型值的相等性。然而,这并不是定义良好的,
原因既包括该运算符返回的是张量而不是布尔值,也因为浮点值的比较应该使用误差范围(或 epsilon)来解释浮点运算的非交换性(有关更多详细信息,请参见此处)。我们可以改用 torch.allclose(),它将考虑到相对和绝对容差阈值,为我们提供近似比较

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中检查转换后的模块与参考实现相比行为是否符合预期的第一个工具。

调试生成的代码#

由于 FX 在 GraphModule 上生成 forward() 函数,使用传统的调试技术(如 print 语句或 pdb)并不那么直接。幸运的是,我们有几种技术可以用来调试生成的代码。

使用 pdb#

调用 pdb 以步进式进入正在运行的程序。尽管代表 Graph 的代码不在任何源文件中,但当调用前向传播时,我们仍然可以使用 pdb 手动步入其中。


import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用 GraphModuleto_folder 函数#

GraphModule.to_folder()GraphModule 中的一个方法,它允许你将生成的 FX 代码导出到一个文件夹。虽然像在 打印生成的代码 中那样将前向传递复制代码通常就足够了,但使用 to_folder 检查模块和参数可能会更容易。


m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

在运行上述示例后,我们可以查看 foo/module.py 中的代码,并根据需要进行修改(例如添加 print 语句或使用 pdb)以调试生成的代码。

调试转换过程#

既然我们已经确定转换正在创建不正确的代码,那么是时候调试转换本身了。首先,我们将检查文档中的 符号追踪的局限性 部分。一旦我们验证追踪按预期工作,目标就变成了找出在我们的 GraphModule 转换过程中出了什么问题。在 编写转换 中可能会有快速解答,但如果没有,有几种方法可以检查我们的被追踪模块


# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上面的实用函数,我们可以对比应用转换前后的追踪模块。有时,简单的视觉对比就足以追踪到错误。如果仍然不清楚哪里出了问题,像 pdb 这样的调试器可能是下一个不错的步骤。

根据上面的示例,考虑以下代码


# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的例子,假设对 print(traced) 的调用向我们展示了转换中存在错误。我们想使用调试器找到出问题的地方。我们开始一个 pdb 会话。我们可以通过在 transform_graph(traced) 处设置断点,然后按 s “步入”对 transform_graph(traced) 的调用,来查看转换期间发生了什么。

我们也可以尝试修改 print_tabular 方法来打印图中节点的不同属性。(例如,我们可能想查看节点的 input_nodesusers。)

可用调试器#

最常用的 Python 调试器是 pdb。你可以通过在命令行输入 python -m pdb FILENAME.py 以“调试模式”启动程序,其中 FILENAME 是你要调试的文件名。之后,你可以使用 pdb 调试器命令 来逐步执行正在运行的程序。通常在启动 pdb 时设置断点 (b LINE-NUMBER),然后调用 c 以运行程序直至该点。这可以避免你必须逐步执行每一行(使用 sn)才能到达你想检查的代码部分。或者,你可以在想要中断的行之前编写 import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),你的程序在运行时将自动以调试模式启动。(换句话说,你只需在命令行输入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦你在调试模式下运行文件,你就可以逐步执行代码并使用某些命令检查程序的内部状态。网上有很多关于 pdb 的优秀教程,包括 RealPython 的 “使用 Pdb 进行 Python 调试”

PyCharm 或 VSCode 等 IDE 通常内置了调试器。在你的 IDE 中,你可以选择 a) 通过在 IDE 中打开终端窗口使用 pdb(例如 VSCode 中的“查看”→“终端”),或者 b) 使用内置调试器(通常是 pdb 的图形包装器)。

符号追踪的局限性#

FX 使用一套符号追踪(又称 符号执行)系统来捕获可转换/可分析形式的程序语义。该系统之所以是追踪 (tracing),是因为它通过执行程序(实际上是 torch.nn.Module 或函数)来记录操作。它又是符号 (symbolic) 的,因为在执行期间流经程序的数据不是真实数据,而是符号(在 FX 术语中称为 Proxy)。

尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。

动态控制流#

符号追踪的主要局限性是它目前不支持动态控制流。也就是说,循环或 if 语句的条件可能取决于程序的输入值。

例如,让我们看下面的程序


def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:

        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
    <...>
    File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
    File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
    File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if 语句的条件依赖于 x.sum() 的值,而该值又依赖于函数输入 x。由于 x 可以变化(即,如果你向被追踪函数传递一个新的输入张量),这就是动态控制流。回溯信息会沿着你的代码向上追溯,向你展示这种情况发生的位置。

静态控制流#

另一方面,支持所谓的静态控制流。静态控制流是指其值在不同调用之间不会改变的循环或 if 语句。通常,在 PyTorch 程序中,这种控制流产生于根据超参数决定模型架构的代码。作为一个具体的例子


import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if 语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。 do_activation 可以被认为是一个超参数,具有该参数不同值的 MyModule 的不同实例的追踪具有不同的代码。这是一种有效的模式,受符号追踪支持。

许多动态控制流的实例在语义上是静态控制流。可以通过消除对输入值的数据依赖性来使这些实例支持符号追踪,例如将值移动到 Module 属性,或在符号追踪期间将具体值绑定到参数


def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对方法(参见 使用 Tracer 类自定义追踪)或函数(参见 wrap())的调用,而不是追踪其内部。

torch 函数#

FX 使用 __torch_function__ 作为其拦截调用的机制(有关更多信息,请参阅 技术概述)。某些函数,例如 Python 内置函数或 math 模块中的函数,不在 __torch_function__ 覆盖范围内,但我们仍然希望在符号追踪中捕获它们。例如


import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
    <...>
    File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
    File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误告诉我们内置函数 len 不受支持。我们可以通过 wrap() API 将此类函数作为直接调用记录在追踪中


torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用 Tracer 类自定义追踪#

Tracer 类是 symbolic_trace 实现的基础类。可以通过继承 Tracer 来自定义追踪行为,如下所示


class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶子模块 (Leaf Modules)#

叶子模块是作为符号追踪中的调用出现的模块,而不是被追踪进去。默认的叶子模块集是标准 torch.nn 模块实例。例如


class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过重写 Tracer.is_leaf_module() 来自定义叶子模块集。

杂项#

  • 张量构造函数(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor)目前不可追踪。

    • 确定性构造函数(zeros, ones)可以使用,它们产生的值将作为常量嵌入到追踪中。这仅在这些构造函数的参数引用动态输入大小时才会有问题。在这种情况下,ones_likezeros_like 可能是可行的替代方案。

    • 非确定性构造函数(rand, randn)将有一个单一的随机值嵌入在追踪中。这可能不是预期的行为。一种变通方法是将 torch.randn 封装在 torch.fx.wrap 函数中并调用该函数。

    
    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能会在未来的版本中得到修正。

  • 类型注解

    • 支持 Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor),并且符号追踪将保留它们。

    • 目前不支持 Python 2 风格的注释类型注解 # type: (torch.Tensor, int) -> torch.Tensor

    • 目前不支持函数内局部变量的注解。

  • 关于 training 标志和子模块的注意事项

    • 在使用 torch.nn.functional.dropout 等 functional 时,通常会将 training 参数作为 self.training 传入。在 FX 追踪期间,这可能会被硬编码为一个常量值。

    
    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
        def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
        dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
        return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 然而,当使用标准的 nn.Dropout() 子模块时,training 标志是被封装的,并且由于保留了 nn.Module 对象模型,它是可以改变的。

    
    class DropoutRepro2(torch.nn.Module):
        def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
        def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
        drop = self.drop(x);  x = None
        return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,请考虑将动态与 training 标志交互的模块标记为叶子模块。

API参考#

torch.fx.symbolic_trace(root, concrete_args=None)[source]#

符号追踪 API

给定一个 nn.Module 或函数实例 root,该函数将返回一个 GraphModule,该模块是通过记录追踪 root 时看到的各操作构建而成的。

concrete_args 允许你对函数进行部分特化,无论是为了消除控制流还是数据结构。

例如

def f(a, b):
    if b == True:
        return a
    else:
        return a * 2

由于控制流的存在,FX 通常无法追踪到此内部。但是,我们可以使用 `concrete_args` 对 `b` 的值进行特化,以便追踪此代码

f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

请注意,尽管你仍然可以传入不同的 `b` 值,但它们将被忽略。

我们还可以使用 `concrete_args` 从函数中消除数据结构处理。这将使用 pytrees 展平你的输入。为了避免过度特化,请为不应特化的值传入 `fx.PH`。例如

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


f = fx.symbolic_trace(
    f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
)
assert f({"a": 1, "b": 2, "c": 4}) == 7
参数:
  • root (Union[torch.nn.Module, Callable]) – 要追踪并转换为 Graph 表示的模块或函数。

  • concrete_args (Optional[Dict[str, any]]) – 要进行部分特化的输入

返回:

根据从 root 记录的操作创建的模块。

返回类型:

GraphModule

注意

保证此 API 的向后兼容性。

torch.fx.wrap(fn_or_name)[source]#

可以在模块级作用域调用此函数,将 fn_or_name 注册为“叶子函数”。 “叶子函数”将在 FX 追踪中保留为 CallFunction 节点,而不是被追踪进去

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y


torch.fx.wrap("my_custom_function")


def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

该函数也可以等效地用作装饰器

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

包装好的函数可以被认为是一个“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在 FX 追踪中作为调用保留而不是被追踪进去的函数。

参数:

fn_or_name (Union[str, Callable]) – 在调用时要插入到图中的全局函数或其名称

注意

保证此 API 的向后兼容性。

class torch.fx.GraphModule(*args, **kwargs)[source]#

GraphModule 是从 fx.Graph 生成的 nn.Module。GraphModule 有一个 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告

graph 被重新赋值时,codeforward 将自动重新生成。但是,如果你在没有重新赋值 graph 属性本身的情况下编辑了 graph 的内容,则必须调用 recompile() 来更新生成的代码。

注意

保证此 API 的向后兼容性。

__init__(root, graph, class_name='GraphModule')[source]#

构建一个 GraphModule。

参数:
  • root (Union[torch.nn.Module, Dict[str, Any]) – root 可以是 nn.Module 实例,也可以是映射字符串到任何属性类型的 Dict。在 root 是模块的情况下,Graph 节点 target 字段中对基于模块的对象的任何引用(通过限定名称)都将从 root 模块层级中的相应位置复制到 GraphModule 的模块层级中。在 root 是字典的情况下,节点 target 中的限定名称将直接在字典键中查找。字典映射的对象将被复制到 GraphModule 模块层级中的适当位置。

  • graph (Graph) – graph 包含此 GraphModule 应用于代码生成的节点

  • class_name (str) – name 表示此 GraphModule 的名称,用于调试。如果未设置,所有错误消息都将报告源自 GraphModule。将其设置为 root 的原始名称或在你的转换上下文中具有意义的名称可能会很有帮助。

注意

保证此 API 的向后兼容性。

add_submodule(target, m)[source]#

将给定的子模块添加到 self

如果它们是 target 的子路径且尚不存在,则安装空模块。

参数:
  • target (str) – 新子模块的完全限定字符串名称(有关如何指定完全限定字符串,请参阅 nn.Module.get_submodule 中的示例。)

  • m (Module) – 子模块本身;我们要安装在当前模块中的实际对象

返回:

子模块是否可以插入。

要使此方法返回 True,由 target 表示的链中的每个对象必须要么 a) 尚不存在,要么 b) 引用一个 nn.Module(不是参数或其他属性)

返回类型:

布尔值

注意

保证此 API 的向后兼容性。

property code: str#

返回从此 GraphModule 底层的 Graph 生成的 Python 代码。

delete_all_unused_submodules()[source]#

self 中删除所有未使用的子模块。

如果满足以下任一条件,则认为模块是“已使用”的:1. 它有被使用的子节点 2. 它的 forward 通过 call_module 节点直接被调用 3. 它有一个非模块属性通过 get_attr 节点被使用

可以调用此方法来清理 nn.Module,而无需手动在每个未使用的子模块上调用 delete_submodule

注意

保证此 API 的向后兼容性。

delete_submodule(target)[source]#

self 中删除给定的子模块。

如果 target 不是有效的目标,则不会删除该模块。

参数:

target (str) – 新子模块的完全限定字符串名称(有关如何指定完全限定字符串,请参阅 nn.Module.get_submodule 中的示例。)

返回:

目标字符串是否引用了

我们要删除的子模块。返回值 False 意味着 target 不是对子模块的有效引用。

返回类型:

布尔值

注意

保证此 API 的向后兼容性。

property graph: Graph#

返回此 GraphModule 底层的 Graph

print_readable(print_output=True, include_stride=False, include_device=False, colored=False, *, fast_sympy_print=False, expanded_def=False, additional_meta=None)[source]#

返回当前 GraphModule 及其子 GraphModule 生成的 Python 代码。

参数:

additional_meta (list[str] | None) – 要包含在输出中的元键的可选列表。对于列表中的每个键,如果它存在于 node.meta 中,则其值将以“键:值”的格式显示。例如:print_readable(additional_meta=[“seq_nr”])

警告

此 API 为实验性质,且保证向后兼容。

recompile()[source]#

从其 graph 属性重新编译此 GraphModule。编辑包含的 graph 后应调用此方法,否则此 GraphModule 生成的代码将过期。

注意

保证此 API 的向后兼容性。

返回类型:

PythonCode

to_folder(folder, module_name='FxModule')[source]#
将模块转储到具有 module_namefolder 中,以便它可以

使用 from <folder> import <module_name> 导入

参数 (Args)

folder (Union[str, os.PathLike]): 要导出代码的文件夹

module_name (str): 导出代码时用于 Module

顶级名称

警告

此 API 为实验性质,且保证向后兼容。

class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#

Graph 是 FX 中间表示中使用的主要数据结构。它由一系列 Node 组成,每个节点代表调用点(或其他语法结构)。这些 Node 列表加在一起构成一个有效的 Python 函数。

例如,以下代码

import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(
            torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
        )


m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下 Graph

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关 Graph 中表示的操作语义,请参阅 Node

注意

保证此 API 的向后兼容性。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#

构建一个空 Graph。

注意

保证此 API 的向后兼容性。

call_function(the_function, args=None, kwargs=None, type_expr=None, name=None)[source]#

Graph 中插入一个 call_function Nodecall_function 节点表示对 Python 可调用对象的调用,由 the_function 指定。

参数:
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数,或 builtinsoperator 命名空间的成员。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用函数的正向参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

  • name (Optional[str]) – 节点的名称。如果未指定,则设置为 None

返回:

新创建并插入的 call_function 节点。

返回类型:

Node

注意

此方法适用与 Graph.create_node() 相同的插入点和类型表达式规则。

注意

保证此 API 的向后兼容性。

call_method(method_name, args=None, kwargs=None, type_expr=None)[源码]#

Graph 中插入一个 call_method Node(节点)。一个 call_method 节点代表对 args 的第 0 个元素调用给定的方法。

参数:
  • method_name (str) – 要应用于 self 参数的方法名称。例如,如果 args[0] 是一个代表 TensorNode,那么要对该 Tensor 调用 relu(),请将 relu 传递给 method_name

  • args (Optional[Tuple[Argument, ...]]) – 传递给被调用方法的对应位置参数。请注意,这应该包含一个 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键字参数

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

返回:

新创建并插入的 call_method 节点。

返回类型:

Node

注意

此方法适用与 Graph.create_node() 相同的插入点和类型表达式规则。

注意

保证此 API 的向后兼容性。

call_module(module_name, args=None, kwargs=None, type_expr=None)[源码]#

Graph 中插入一个 call_module Node。一个 call_module 节点代表对 Module 层级结构中某个 Module 的 forward() 函数的调用。

参数:
  • module_name (str) – 要调用的 ModuleModule 层级结构中的限定名。例如,如果被追踪的 Module 有一个名为 foo 的子模块,而 foo 又有一个名为 bar 的子模块,则应将限定名 foo.bar 作为 module_name 传递以调用该模块。

  • args (Optional[Tuple[Argument, ...]]) – 传递给被调用方法的对应位置参数。请注意,这不应包含 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键字参数

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

返回:

新创建并插入的 call_module 节点。

返回类型:

Node

注意

此方法适用与 Graph.create_node() 相同的插入点和类型表达式规则。

注意

保证此 API 的向后兼容性。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[源码]#

创建一个 Node 并将其添加到 Graph 的当前插入点。请注意,当前插入点可以通过 Graph.inserting_before()Graph.inserting_after() 进行设置。

参数:
  • op (str) – 此节点的 opcode(操作码)。其值为 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在 Graph 的文档字符串中有所描述。

  • args (Optional[Tuple[Argument, ...]]) – 此节点的参数元组。

  • kwargs (Optional[Dict[str, Argument]]) – 此节点的关键字参数

  • name (Optional[str]) – Node 的可选字符串名称。这将影响生成的 Python 代码中被赋值的变量名称。

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

返回:

新创建并插入的节点。

返回类型:

Node

注意

保证此 API 的向后兼容性。

eliminate_dead_code(is_impure_node=None)[源码]#

根据每个节点的使用者数量以及节点是否有副作用,从图中移除所有死代码。在调用此方法之前,图必须已经过拓扑排序。

参数:
  • is_impure_node (Optional[Callable[[Node], bool]]) – 一个返回以下内容的函数:

  • None (节点是否为非纯节点。如果此参数为) –

  • to (,则默认行为是) –

  • Node.is_impure. (使用) –

返回:

图是否因该 pass 而发生了改变。

返回类型:

布尔值

示例

在死代码消除之前,下面 a = x + 1 中的 a 没有使用者,因此可以从图中消除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

在死代码消除之后,a = x + 1 已被移除,而 forward 的其余部分保留下来。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除包含一些启发式方法,以避免移除有副作用的节点(参见 Node.is_impure),但通常其覆盖范围非常有限。因此,除非你确定 FX 图完全由函数式操作组成,或者你提供了自定义的用于检测副作用节点的函数,否则你不应该假定调用此方法是绝对稳妥的。

注意

保证此 API 的向后兼容性。

erase_node(to_erase)[源码]#

Graph 中删除一个 Node。如果该节点在 Graph 中仍有使用者,则抛出异常。

参数:

to_erase (Node) – 要从 Graph 中删除的 Node

注意

保证此 API 的向后兼容性。

find_nodes(*, op, target=None, sort=True)[源码]#

支持对节点进行快速查询

参数:
  • op (str) – 操作的名称

  • target (Optional[Target]) – 节点的目标。对于 call_function,必须提供 target。对于其他操作,target 是可选的。

  • sort (bool) – 是否按节点在图中出现的顺序返回节点。

返回:

具有请求的 op 和 target 的节点可迭代对象。

警告

此 API 为实验性质,且保证向后兼容。

get_attr(qualified_name, type_expr=None)[源码]#

在 Graph 中插入一个 get_attr 节点。一个 get_attr Node 代表从 Module 层级结构中获取一个属性。

参数:
  • qualified_name (str) – 要检索的属性的全限定名。例如,如果被追踪的 Module 有一个名为 foo 的子模块,该子模块有一个名为 bar 的子模块,而 bar 又有一个名为 baz 的属性,则应将 foo.bar.baz 作为 qualified_name 传递。

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

返回:

新创建并插入的 get_attr 节点。

返回类型:

Node

注意

此方法的插入点和类型表达式规则与 Graph.create_node 相同。

注意

保证此 API 的向后兼容性。

graph_copy(g, val_map, return_output_node=False)[源码]#

将给定图中的所有节点复制到 self 中。

参数:
  • g (Graph) – 要从中复制节点的源图。

  • val_map (Dict[Node, Node]) – 一个字典,将被填充为从 g 中的节点到 self 中的节点的映射。请注意,传入 val_map 时可以预先包含一些值,以覆盖某些特定值的复制行为。

返回:

如果 g 具有 output 节点,则返回 self 中现在等效于 g 中输出值的值;否则返回 None

返回类型:

元组[Argument, …] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None

注意

保证此 API 的向后兼容性。

inserting_after(n=None)[源码]#
设置 create_node 及其配套方法在图中插入的位置。

在 ‘with’ 语句中使用时,这将临时设置插入点,并在 with 语句退出时恢复它

with g.inserting_after(n):
    ...  # inserting after node n
...  # insert point restored to what it was previously
g.inserting_after(n)  #  set the insert point permanently

参数 (Args)

n (Optional[Node]): 在其之后插入的节点。如果为 None,则将插入到

整个图的最开始之后。

返回

一个资源管理器,将在 __exit__ 时恢复插入点。

注意

保证此 API 的向后兼容性。

inserting_before(n=None)[源码]#
设置 create_node 及其配套方法在图中插入的位置。

在 ‘with’ 语句中使用时,这将临时设置插入点,并在 with 语句退出时恢复它

with g.inserting_before(n):
    ...  # inserting before node n
...  # insert point restored to what it was previously
g.inserting_before(n)  #  set the insert point permanently

参数 (Args)

n (Optional[Node]): 在其之前插入的节点。如果为 None,则将插入到

整个图的最开始之后。

返回

一个资源管理器,将在 __exit__ 时恢复插入点。

注意

保证此 API 的向后兼容性。

lint()[源码]#

对该 Graph 运行各种检查,以确保其结构良好。特别是: - 检查 Node 是否拥有正确的所有权(归此图所有) - 检查 Node 是否按拓扑顺序出现 - 如果此 Graph 有所属的 GraphModule,则检查 target 是否存在于该 GraphModule 中

注意

保证此 API 的向后兼容性。

node_copy(node, arg_transform=<function Graph.<lambda>>)[源码]#

将一个节点从一个图复制到另一个图中。arg_transform 需要将 node 原属图中的参数转换为 self 所属图中的参数。示例:

# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
参数:
  • node (Node) – 要复制到 self 中的节点。

  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将原节点的 argskwargs 中的 Node 参数转换为 self 中等效的参数。在最简单的情况下,它应该从一个将原图节点映射到 self 节点的表中检索值。

返回类型:

Node

注意

保证此 API 的向后兼容性。

property nodes: _node_list#

获取构成此 Graph 的 Node 列表。

请注意,此 Node 列表表示形式是一个双向链表。迭代期间的变动(例如删除节点、添加节点)是安全的。

返回:

一个双向链表形式的 Node。请注意,可以对此列表调用 reversed 以切换迭代顺序。

on_generate_code(make_transformer)[源码]#

在生成 Python 代码时注册一个转换器(transformer)函数

参数 (Args)
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])

一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。

此函数还会接收当前注册的代码转换器作为输入(如果未注册则为 None),以防不希望覆盖它。这对于将代码转换器链式连接在一起非常有用。

返回

一个上下文管理器,当在 with 语句中使用时,会自动恢复以前注册的代码转换器。

示例

gm: fx.GraphModule = ...


# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]


# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此函数也可以用作上下文管理器,优点是会自动恢复以前注册的代码转换器

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 为实验性质,且保证向后兼容。

output(result, type_expr=None)[源码]#

Graph 中插入一个 output Node。一个 output 节点代表 Python 代码中的一个 return 语句。result 是应该返回的值。

参数:
  • result (Argument) – 要返回的值。

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

注意

此方法的插入点和类型表达式规则与 Graph.create_node 相同。

注意

保证此 API 的向后兼容性。

output_node()[源码]#

警告

此 API 为实验性质,且保证向后兼容。

返回类型:

Node

placeholder(name, type_expr=None, default_value)[源码]#

在 Graph 中插入一个 placeholder 节点。一个 placeholder 代表函数的一个输入。

参数:
  • name (str) – 输入值的名称。这对应于此 Graph 所代表函数的对应位置参数名称。

  • type_expr (Optional[Any]) – 一个可选的类型注解,代表此节点输出将具有的 Python 类型。在某些情况下,为了正确生成代码(例如,当该函数后续用于 TorchScript 编译时),这是必需的。

  • default_value (Any) – 此函数参数应采用的默认值。注意:为了允许将 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指明该参数_没有_默认值。

返回类型:

Node

注意

此方法的插入点和类型表达式规则与 Graph.create_node 相同。

注意

保证此 API 的向后兼容性。

print_tabular()[源码]#

以表格格式打印图的中间表示。注意,此 API 需要安装 tabulate 模块。

注意

保证此 API 的向后兼容性。

process_inputs(*args)[源码]#

处理参数,以便将它们传递给 FX 图。

警告

此 API 为实验性质,且保证向后兼容。

process_outputs(out)[源码]#

警告

此 API 为实验性质,且保证向后兼容。

python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False, expanded_def=False, record_func=False, additional_meta=None)[源码]#

将此 Graph 转换为有效的 Python 代码。

参数:

root_module (str) – 根模块的名称,用于查找限定名目标。通常为 ‘self’。

返回:

src: 代表该对象的 Python 源代码。globals: src 中全局名称与其引用对象之间的映射字典。

返回类型:

一个 PythonCode 对象,包含两个字段

注意

保证此 API 的向后兼容性。

set_codegen(codegen)[源码]#

警告

此 API 为实验性质,且保证向后兼容。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[源码]#

Node 是代表 Graph 中单个操作的数据结构。在大多数情况下,Node 代表对各种实体的调用点,如算子、方法和 Module(一些例外包括指定函数输入和输出的节点)。每个 Node 都有一个由其 op 属性指定的函数。每种 op 取值的 Node 语义如下:

  • placeholder 代表函数输入。name 属性指定该值将采用的名称。target 同样是参数的名称。args 包含以下内容之一:1) 无内容,或 2) 代表函数输入默认参数的单个参数。kwargs 不受关注。占位符对应于图打印输出中的函数参数(例如 x)。

  • get_attr 从模块层级结构中检索参数。name 同样是获取结果被赋值给的名称。target 是参数在模块层级结构中位置的全限定名。argskwargs 不受关注。

  • call_function 将自由函数应用于某些值。name 同样是要赋值的值的名称。target 是要应用的函数。argskwargs 代表函数的参数,遵循 Python 调用约定。

  • call_module 将模块层级结构中模块的 forward() 方法应用于给定参数。name 同前。target 是要调用的模块在模块层级结构中的全限定名。argskwargs 代表调用模块时所用的参数,不包括 self 参数

  • call_method 在一个值上调用方法。name 类似。target 是要应用于 self 参数的方法字符串名称。argskwargs 代表调用该方法时所用的参数,包括 self 参数

  • output 在其 args[0] 属性中包含被追踪函数的输出。这对应于 Graph 打印输出中的 “return” 语句。

注意

保证此 API 的向后兼容性。

property all_input_nodes: list[Node]#

返回作为此 Node 输入的所有 Node。这等同于遍历 argskwargs 并仅收集作为 Node 的值。

返回:

按顺序出现在此 Nodeargskwargs 中的 Nodes 列表。

append(x)[源码]#

在图的节点列表中将 x 插入到此节点之后。等同于 self.next.prepend(x)

参数:

x (Node) – 要放在此节点之后的节点。必须是同一个图的成员。

注意

保证此 API 的向后兼容性。

property args: tuple[tuple[Argument, ...] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None ...]#

Node 的参数元组。对参数的解释取决于节点的 opcode。更多信息请参见 Node 的文档字符串。

允许对此属性进行赋值。赋值时,所有对使用情况(uses)和使用者(users)的记账都会自动更新。

format_node(placeholder_names=None, maybe_return_typename=None, *, include_tensor_metadata=False)[源码]#

返回 self 的描述性字符串表示。

此方法可在不带参数的情况下作为调试实用程序使用。

此函数还在 Graph__str__ 方法内部使用。通过 placeholder_namesmaybe_return_typename 中的字符串,共同构成了该 Graph 所属 GraphModule 中自动生成的 forward 函数的签名。不应在其他情况下使用 placeholder_namesmaybe_return_typename

参数:
  • placeholder_names (list[str] | None) – 一个列表,用于存储代表生成的 forward 函数中占位符的格式化字符串。仅限内部使用。

  • maybe_return_typename (list[str] | None) – 一个单元素列表,用于存储代表生成的 forward 函数输出的格式化字符串。仅限内部使用。

  • include_tensor_metadata (bool) – 是否包含张量元数据

返回:

如果 1) 我们正在将 format_node 用作内部辅助工具

Graph__str__ 方法中,并且 2) self 是一个占位符 Node,则返回 None。否则,返回当前 Node 的描述性字符串表示。

返回类型:

str

注意

保证此 API 的向后兼容性。

insert_arg(idx, arg)[源码]#

将位置参数插入到给定索引的参数列表中。

参数:
  • idx (int) – 要在其之前插入的 self.args 中的元素索引。

  • arg (Argument) – 要插入到 args 中的新参数值

注意

保证此 API 的向后兼容性。

is_impure(impure_random=True)[源码]#

返回此操作是否是非纯的(impure),即它的 op 是否是占位符(placeholder)或输出(output),或者它是否是一个非纯的 call_function 或 call_module。

参数:

impure_random (bool) – 是否将随机操作视为非纯操作。

返回:

该操作是否为非纯操作。

返回类型:

布尔值

警告

此 API 为实验性质,且保证向后兼容。

property kwargs: dict[str tuple[Argument ...] | Sequence[Argument] | Mapping[str Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None]#

Node 的关键字参数字典。对参数的解释取决于节点的 opcode。更多信息请参见 Node 的文档字符串。

允许对此属性进行赋值。赋值时,所有对使用情况(uses)和使用者(users)的记账都会自动更新。

property next: Node#

返回节点链表中的下一个 Node

返回:

节点链表中的下一个 Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[源码]#

返回 Python 目标的规范化参数。这意味着 args/kwargs 将与模块/函数的签名进行匹配,如果 normalize_to_only_use_kwargs 为真,则按位置顺序排他地返回 kwargs。此方法还会填充默认值。不支持仅限位置的参数(positional-only parameters)或可变长参数(varargs parameters)。

支持模块调用。

可能需要 arg_typeskwarg_types 以便消除重载的歧义。

参数:
  • root (torch.nn.Module) – 用于解析模块目标的模块。

  • arg_types (Optional[Tuple[Any]]) – 位置参数的类型元组

  • kwarg_types (Optional[Dict[str, Any]]) – 关键字参数的类型字典

  • normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用关键字参数。

返回:

返回命名元组 ArgsKwargsPair,如果不成功则返回 None

返回类型:

ArgsKwargsPair | None

警告

此 API 为实验性质,且保证向后兼容。

prepend(x)[源码]#

在图的节点列表中将 x 插入到此节点之前。示例:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数:

x (Node) – 要放在此节点之前的节点。必须是同一个图的成员。

注意

保证此 API 的向后兼容性。

property prev: Node#

返回节点链表中的上一个 Node

返回:

节点链表中的上一个 Node

replace_all_uses_with(replace_with, delete_user_cb=None, *, propagate_meta=False)[源码]#

在图中将所有对 self 的使用替换为节点 replace_with

参数:
  • replace_with (Node) – 用于替换所有对 self 使用的节点。

  • delete_user_cb (Callable) – 被调用的回调函数,用于确定是否应移除 self 节点的给定使用者。

  • propagate_meta (bool) – 是否将原始节点 .meta 字段上的所有属性复制到替换节点上。为了安全起见,仅当替换节点尚不存在 .meta 字段时,执行此操作才有效。

返回:

进行了此更改的 Node 列表。

返回类型:

list[Node]

注意

保证此 API 的向后兼容性。

replace_input_with(old_input, new_input)[源码]#

遍历 self 的输入节点,并将所有 old_input 实例替换为 new_input

参数:
  • old_input (Node) – 要被替换的旧输入节点。

  • new_input (Node) – 用于替换 old_input 的新输入节点。

注意

保证此 API 的向后兼容性。

property stack_trace: str | None#

返回在追踪期间记录的 Python 调用栈轨迹(如果有)。当使用 fx.Tracer 追踪时,此属性通常由 Tracer.create_proxy 填充。要在追踪期间出于调试目的记录栈轨迹,请在 Tracer 实例上设置 record_stack_traces = True。当使用 dynamo 追踪时,此属性默认由 OutputGraph.create_proxy 填充。

stack_trace 的字符串末尾将是其最内层的调用帧。

update_arg(idx, arg)[源码]#

更新现有的位置参数以包含新值 arg。调用后,self.args[idx] == arg

参数:
  • idx (int) – 要更新的元素在 self.args 中的索引

  • arg (Argument) – 要写入 args 的新参数值

注意

保证此 API 的向后兼容性。

update_kwarg(key, arg)[源码]#

更新现有的关键字参数以包含新值 arg。调用后,self.kwargs[key] == arg

参数:
  • key (str) – 要更新的元素在 self.kwargs 中的键

  • arg (Argument) – 要写入 kwargs 的新参数值

注意

保证此 API 的向后兼容性。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源码]#

Tracer 是实现 torch.fx.symbolic_trace 符号追踪功能的类。调用 symbolic_trace(m) 等同于 Tracer().trace(m)

可以对 Tracer 进行子类化,以覆盖追踪过程中的各种行为。可以覆盖的不同行为在该类的各方法文档字符串中进行了描述。

注意

保证此 API 的向后兼容性。

call_module(m, forward, args, kwargs)[源码]#

该方法指定了当 Tracer 遇到对 nn.Module 实例的调用时的行为。

默认行为是通过 is_leaf_module 检查被调用的模块是否为叶子模块。如果是,则在 Graph 中发射(emit)一个指向 mcall_module 节点。否则,正常调用该 Module,并追踪其 forward 函数中的操作。

可以覆盖此方法以实现其他行为——例如,创建嵌套的追踪 GraphModules,或在跨越 Module 边界追踪时所需的任何其他行为。

参数:
  • m (Module) – 正在发射调用节点的模块

  • forward (Callable) – 要被调用的 Module 的 forward() 方法

  • args (Tuple) – 模块调用点的位置参数

  • kwargs (Dict) – 模块调用点的关键字参数

返回:

Module 调用的返回值。如果发出了 call_module 节点,则这是一个 Proxy 值。否则,它是 Module 调用返回的任何值。

返回类型:

任何

注意

保证此 API 的向后兼容性。

create_arg(a)[source]#

一种指定在准备用作 Graph 中节点参数的值时的追踪(tracing)行为的方法。

默认情况下,行为包括:

  1. 迭代集合类型(例如 tuple、list、dict)并对其中的元素递归调用 create_args

  2. 给定一个 Proxy 对象,返回对底层 IR Node 的引用。

  3. 给定一个非 Proxy Tensor 对象,针对各种情况发出 IR。

    • 对于 Parameter,发出一个引用该 Parameter 的 get_attr 节点。

    • 对于非 Parameter Tensor,将该 Tensor 存储在一个引用该属性的特殊属性中。

可以重写此方法以支持更多类型。

参数:

a (Any) – 将在 Graph 中作为 Argument 发出的值。

返回:

转换为适当 Argument 的值 a

返回类型:

Argument

注意

保证此 API 的向后兼容性。

create_args_for_root(root_fn, is_module, concrete_args=None)[source]#

创建与 root Module 签名相对应的 placeholder 节点。此方法会内省 root 的签名并相应地发出这些节点,同时也支持 *args**kwargs

警告

此 API 为实验性质,且保证向后兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]#

根据给定的 target、args、kwargs 和 name 插入一个图节点。

可以重写此方法,以对节点创建中使用的值进行额外的检查、验证或修改。例如,用户可能希望禁止记录就地(in-place)操作。

注意

保证此 API 的向后兼容性。

返回类型:

Node

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]#

根据给定参数创建一个 Node,然后返回封装在 Proxy 对象中的该 Node。

如果 kind = 'placeholder',那么我们正在创建一个代表函数参数的 Node。如果我们需要编码默认参数,我们使用 args 元组。对于其他的 placeholder Node,args 为空。

注意

保证此 API 的向后兼容性。

get_fresh_qualname(prefix)[source]#

获取前缀的新名称并返回。此函数确保它不会与图上的现有属性冲突。

注意

保证此 API 的向后兼容性。

返回类型:

str

getattr(attr, attr_val, parameter_proxy_cache)[source]#

当我们在调用 nn.Module 实例时调用 getattr,该方法指定此 Tracer 的行为。

默认情况下,其行为是返回该属性的 proxy 值。它还会将 proxy 值存储在 parameter_proxy_cache 中,以便后续调用重用该 proxy,而不是创建新的 proxy。

可以重写此方法,例如在查询参数时不返回 proxy。

参数:
  • attr (str) – 正在查询的属性名称

  • attr_val (Any) – 属性的值

  • parameter_proxy_cache (Dict[str, Any]) – 属性名称到 proxy 的映射缓存

返回:

getattr 调用的返回值。

警告

此 API 为实验性质,且保证向后兼容。

is_leaf_module(m, module_qualified_name)[source]#

一种指定给定 nn.Module 是否为“叶子”模块的方法。

叶子模块是出现在 IR 中的原子单元,由 call_module 调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块都是叶子模块。除非通过此参数另行指定,否则所有其他模块都会被追踪并记录其组成的算子。

参数:
  • m (Module) – 正在查询的模块

  • module_qualified_name (str) – 此模块到 root 的路径。例如,如果您有一个模块层次结构,其中子模块 foo 包含子模块 bar,而 bar 又包含子模块 baz,则该模块在此处将以限定名称 foo.bar.baz 出现。

返回类型:

布尔值

注意

保证此 API 的向后兼容性。

iter(obj)[source]#
当正在对 proxy 对象进行迭代时调用,例如:

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道 proxy 的值,但自定义 tracer 可以使用 create_node 向图节点附加更多信息,并可以选择返回一个迭代器。

注意

保证此 API 的向后兼容性。

返回类型:

迭代器 (Iterator)

keys(obj)[source]#
当调用 proxy 对象的 keys() 方法时被调用。

这是在 proxy 上调用 ** 时发生的情况。如果 ** 应该在您的自定义 tracer 中工作,这应该返回一个迭代器。

注意

保证此 API 的向后兼容性。

返回类型:

任何

path_of_module(mod)[source]#

root 的模块层级结构中查找 mod 限定名称的辅助方法。例如,如果 root 有一个名为 foo 的子模块,而 foo 又有一个名为 bar 的子模块,将 bar 传入此函数将返回字符串 "foo.bar"。

参数:

mod (str) – 要检索其限定名称的 Module

返回类型:

str

注意

保证此 API 的向后兼容性。

proxy(node)[source]#

注意

保证此 API 的向后兼容性。

返回类型:

Proxy

to_bool(obj)[source]#
当 proxy 对象正在被转换为布尔值时调用,例如:

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道 proxy 的值,但自定义 tracer 可以使用 create_node 向图节点附加更多信息,并可以选择返回一个值。

注意

保证此 API 的向后兼容性。

返回类型:

布尔值

trace(root, concrete_args=None)[source]#

追踪 root 并返回相应的 FX Graph 表示。root 既可以是 nn.Module 实例,也可以是 Python 可调用对象。

请注意,在此调用之后,self.root 可能与此处传入的 root 不同。例如,当一个自由函数被传递给 trace() 时,我们将创建一个 nn.Module 实例作为 root,并在其中添加嵌入常量。

参数:
  • root (Union[Module, Callable]) – 要追踪的模块或函数。保证此参数的向后兼容性。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为 Proxy 的具体参数。此参数是实验性的,其向后兼容性 保证。

返回:

表示传入的 root 语义的 Graph

返回类型:

Graph

注意

保证此 API 的向后兼容性。

class torch.fx.Proxy(node, tracer=None)[source]#

Proxy 对象是 Node 包装器,它们在符号追踪(symbolic tracing)过程中流经程序,并记录它们触及的所有操作(torch 函数调用、方法调用、算子)到不断增长的 FX Graph 中。

如果您正在进行图转换,可以将您自己的 Proxy 方法包装在原始 Node 周围,以便可以使用重载运算符向 Graph 添加额外内容。

Proxy 对象无法被迭代。换句话说,如果在循环中或作为 *args/**kwargs 函数参数使用 Proxy,符号追踪器将抛出错误。

解决此问题主要有两种方法:1. 将不可追踪的逻辑提取到顶级函数中,并在其上使用 fx.wrap。2. 如果控制流是静态的(即循环行程次数基于某些超参数),代码可以保留在其原始位置,并重构为类似如下的形式:

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关 Proxy 内部机制的更详细说明,请查看 torch/fx/README.md 中的“Proxy”部分。

注意

保证此 API 的向后兼容性。

class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source]#

Interpreter(解释器)逐节点执行 FX 图。此模式对许多事情都很有用,包括编写代码转换以及分析传递(analysis passes)。

Interpreter 类中的方法可以被重写以自定义执行行为。按调用层级划分的可重写方法映射如下:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们要将所有 torch.neg 实例换成 torch.sigmoid,反之亦然(包括它们的 Tensor 等效方法)。我们可以像这样子类化 Interpreter:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(
        self, target: Target, args: Tuple, kwargs: Dict
    ) -> Any:
        if target is torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数:
  • module (torch.nn.Module) – 要执行的模块

  • garbage_collect_values (bool) – 是否在模块执行过程中最后一次使用值后将其删除。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看 Interpreter.env 属性来检查执行中的所有中间值。

  • graph (Optional[Graph]) – 如果传递了此参数,解释器将执行此图而不是 module.graph,并使用提供的 module 参数来满足任何状态请求。

注意

保证此 API 的向后兼容性。

boxed_run(args_list)[source]#

通过解释方式运行 module 并返回结果。这使用了“boxed”调用约定,即您传递一个参数列表,该列表将被解释器清除。这确保了输入张量被及时释放。

注意

保证此 API 的向后兼容性。

call_function(target, args, kwargs)[source]#

执行 call_function 节点并返回结果。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型:

任何

Return

Any:函数调用返回的值

注意

保证此 API 的向后兼容性。

call_method(target, args, kwargs)[source]#

执行 call_method 节点并返回结果。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型:

任何

Return

Any:方法调用返回的值

注意

保证此 API 的向后兼容性。

call_module(target, args, kwargs)[source]#

执行 call_module 节点并返回结果。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型:

任何

Return

Any:模块调用返回的值

注意

保证此 API 的向后兼容性。

fetch_args_kwargs_from_env(n)[source]#

从当前执行环境中获取节点 nargskwargs 的具体值。

参数:

n (Node) – 应当获取其 argskwargs 的节点。

返回:

具有节点 n 具体值的 argskwargs

返回类型:

Tuple[Tuple, Dict]

注意

保证此 API 的向后兼容性。

fetch_attr(target)[source]#

self.module 的模块层级结构中获取一个属性。

参数:

target (str) – 要获取的属性的全限定名称

返回:

该属性的值。

返回类型:

任何

注意

保证此 API 的向后兼容性。

get_attr(target, args, kwargs)[source]#

执行 get_attr 节点。将从 self.module 的模块层级结构中检索一个属性值。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回:

检索到的属性值

返回类型:

任何

注意

保证此 API 的向后兼容性。

map_nodes_to_values(args, n)[source]#

递归遍历 args,并在当前执行环境中查找每个 Node 的具体值。

参数:
  • args (Argument) – 在其中查找具体值的数据结构

  • n (Node) – args 所属的节点。这仅用于错误报告。

返回类型:

元组[Argument, …] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None

注意

保证此 API 的向后兼容性。

output(target, args, kwargs)[source]#

执行 output 节点。这实际上只是检索 output 节点所引用的值并将其返回。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回:

由输出节点引用的返回值

返回类型:

任何

注意

保证此 API 的向后兼容性。

placeholder(target, args, kwargs)[source]#

执行 placeholder 节点。请注意,这是有状态的:Interpreter 维护一个关于传给 run 的参数的内部迭代器,并且此方法返回该迭代器上的 next()。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回:

检索到的参数值。

返回类型:

任何

注意

保证此 API 的向后兼容性。

run(*args, initial_env=None, enable_io_processing=True)[source]#

通过解释运行 module 并返回结果。

参数:
  • *args – 要运行的 Module 的参数,按位置顺序排列

  • initial_env (Optional[Dict[Node, Any]]) – 一个可选的初始执行环境。这是一个将 Node 映射到任何值的字典。例如,这可以用于为某些 Node 预先填充结果,从而在解释器内仅进行部分评估。

  • enable_io_processing (bool) – 如果为 true,我们在使用输入和输出之前,先用图的 process_inputs 和 process_outputs 函数对其进行处理。

返回:

执行 Module 后返回的值

返回类型:

任何

注意

保证此 API 的向后兼容性。

run_node(n)[source]#

运行特定的节点 n 并返回结果。根据 node.op 调用 placeholder、get_attr、call_function、call_method、call_module 或 output。

参数:

n (Node) – 要执行的节点

返回:

执行 n 的结果

返回类型:

任何

注意

保证此 API 的向后兼容性。

class torch.fx.Transformer(module)[source]#

Transformer 是一种特殊类型的解释器,它会产生一个新的 Module。它公开了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 不需要像 Interpreter 那样提供运行参数。Transformer 完全以符号化方式工作。

示例

假设我们要将所有 torch.neg 实例换成 torch.sigmoid,反之亦然(包括它们的 Tensor 等效方法)。我们可以像这样子类化 Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(
        self,
        target: "Target",
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        if target is torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(
        self,
        target: "Target",
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)

transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数:

module (GraphModule) – 要转换的 Module

注意

保证此 API 的向后兼容性。

call_function(target, args, kwargs)[source]#

注意

保证此 API 的向后兼容性。

返回类型:

任何

call_module(target, args, kwargs)[source]#

注意

保证此 API 的向后兼容性。

返回类型:

任何

get_attr(target, args, kwargs)[source]#

执行 get_attr 节点。在 Transformer 中,此方法被重写,以便在输出图中插入一个新的 get_attr 节点。

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型:

Proxy

注意

保证此 API 的向后兼容性。

placeholder(target, args, kwargs)[source]#

执行 placeholder 节点。在 Transformer 中,此方法被重写,以便在输出图中插入一个新的 placeholder

参数:
  • target (Target) – 此节点的调用目标。有关语义详情,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型:

Proxy

注意

保证此 API 的向后兼容性。

transform()[source]#

转换 self.module 并返回转换后的 GraphModule

注意

保证此 API 的向后兼容性。

返回类型:

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[source]#

在 GraphModule (gm) 的图中匹配所有可能的非重叠算子集及其数据依赖项 (pattern),然后将每个匹配的子图替换为另一个子图 (replacement)。

参数:
返回:

一个 Match 对象列表,表示原始图中与 pattern 匹配的位置。如果没有匹配项,列表为空。Match 定义为:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回类型:

List[Match]

示例

import torch
from torch.fx import symbolic_trace, subgraph_rewriter


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2])


def replacement(w1, w2):
    return torch.stack([w1, w2])


traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先在 traced_moduleforward 方法中匹配 pattern。模式匹配是基于使用-定义(use-def)关系完成的,而不是节点名称。例如,如果你在 pattern 中有 p = torch.cat([a, b]),你可以在原始的 forward 函数中匹配 m = torch.cat([a, b]),尽管变量名不同(pm)。

pattern 中的 return 语句仅根据其值进行匹配;它可能匹配也可能不匹配较大图中的 return 语句。换句话说,模式不必延伸到较大图的末尾。

当模式匹配成功时,它将从较大函数中删除并由 replacement 替换。如果较大函数中有多个 pattern 匹配项,则每个非重叠匹配项都将被替换。在匹配重叠的情况下,将替换重叠匹配项集中找到的第一个匹配项。(此处的“第一个”定义为节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在 self 之后的参数,而最后一个节点是函数返回的任何内容。)

需要注意的一件重要事情是,pattern 可调用对象的参数必须在可调用对象本身中使用,并且 replacement 可调用对象的参数必须与 pattern 匹配。第一条规则解释了为什么在上面的代码块中,forward 函数有参数 x, w1, w2,但 pattern 函数只有参数 w1, w2pattern 不使用 x,因此它不应将 x 指定为参数。作为第二条规则的示例,考虑替换:

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

替换

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同数量的参数(包括 xy),即使参数 yreplacement 中并未使用。

调用 subgraph_rewriter.replace_pattern 后,生成的 Python 代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

保证此 API 的向后兼容性。

torch.fx.traceback.annotate(annotation_dict)[source]#

暂时向当前追踪上下文添加自定义注释。从此追踪上下文产生的 fx_node 将在 node.metadata["custom"] 字段中包含自定义注释。

此上下文管理器允许您通过更新全局 current_meta["custom"] 字典,将任意元数据插入 PT2 追踪系统。注释在上下文退出后自动还原。

梯度累积节点将不会被注释。

这旨在供高级用户使用,他们需要在导出追踪期间向 fx 节点附加额外的元数据(例如,用于调试、分析或外部工具)。

注意

此 API 不具有向后兼容性,并且可能在未来版本中发生演变。

注意

此 API 与 fx.symbolic_trace 或 jit.trace 不兼容。它旨在与 PT2 系列追踪器(例如 torch.export 和 dynamo)配合使用。

参数:

annotation_dict (dict) – 要注入到 FX 追踪元数据中的自定义键值对字典。

示例

退出上下文后,自定义注释将被移除。

>>> with annotate({"source": "custom_pass", "tag": 42}):
...     pass  # Your computation here

警告

此 API 为实验性质,且保证向后兼容。

torch.fx.passes.tools_common.stable_topological_sort(gm)[source]#

将给定 GraphModule 的图替换为一个包含与原始图相同节点、但按拓扑排序排列的图,同时尽可能保留原始节点顺序。

此函数执行稳定拓扑排序,其中节点出现的顺序:1. 遵循数据依赖关系(拓扑排序)2. 在没有依赖约束时保留原始节点顺序。

该算法使用带优先级队列的 Kahn 算法:所有依赖项均已满足的节点将被添加到最小堆中,并按其原始位置排序。这确保了我们在就绪节点中始终处理原始顺序中最靠前的节点。

参数:

gm (GraphModule) – 要进行拓扑排序的图模块。它是就地(in-place)修改的。

返回:

就地排序后的图模块

返回类型:

GraphModule

警告

此 API 为实验性质,且保证向后兼容。

torch.fx.passes.split_utils.move_non_tensor_nodes_on_boundary(subgraphs)[source]#

移动子图边界上的非张量(non-tensor)节点。

对于每个子图:

  1. 查找类型不是张量且其任何子节点位于另一个子图中的节点,将它们放入队列中以进行下一步

  2. 对队列中的那些节点执行 BFS,并对每个节点(假设节点为 X 且位于子图 A 中)运行 DFS

    1. 如果在 to_subgraph 中,则返回(继续 DFS)

    2. 如果在 from_subgraph 中,则将这些节点收集到 nodes_to_move 中,并继续 DFS

    3. 否则,这意味着它无法被移动

    4. 还要检查节点 X 的父节点是否应该放入队列中。(队列中可能有重复节点,只需处理一次该节点)

参数:

subgraphs – 包含待处理节点的子图列表

警告

此 API 为实验性质,且保证向后兼容。