评价此页

torch.fx#

创建于: 2020年12月15日 | 最后更新于: 2025年07月15日

概述#

FX 是一个供开发者用来转换 nn.Module 实例的工具包。FX 包含三个主要组件:符号追踪器中间表示Python 代码生成。这些组件协同工作的演示

import torch


# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪器 对 Python 代码执行“符号执行”。它通过代码传递称为代理(Proxies)的假值。对这些代理的操作会被记录下来。有关符号追踪的更多信息可以在 symbolic_trace()Tracer 文档中找到。

中间表示 是用于存放符号追踪期间记录的操作的容器。它包含一系列节点(Nodes),这些节点代表函数输入、调用点(指向函数、方法或 torch.nn.Module 实例)以及返回值。有关 IR 的更多信息可以在 Graph 文档中找到。IR 是应用转换的格式。

Python 代码生成 是 FX 成为 Python 到 Python(或模块到模块)转换工具包的原因。对于每个 Graph IR,我们可以生成与 Graph 语义相匹配的有效 Python 代码。此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,其中包含一个 Graph 以及从 Graph 生成的 forward 方法。

总而言之,这个组件管道(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件也可以单独使用。例如,符号追踪可以单独用于捕获代码的一种形式以便进行分析(而非转换)。代码生成可用于以编程方式生成模型,例如从配置文件生成。FX 有许多用途!

可以在 examples 存储库中找到几个转换示例。

编写转换#

什么是 FX 转换?本质上,它是一个如下所示的函数。


import torch
import torch.fx

def transform(m: nn.Module,
                tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

您的转换将接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,然后返回一个新的 torch.nn.Module。您应该将 FX 转换返回的 torch.nn.Module 视为与常规 torch.nn.Module 相同 —— 您可以将其传递给另一个 FX 转换,或者运行它。确保您的 FX 转换的输入和输出是 torch.nn.Module 将允许组合。

注意

也可以修改现有的 GraphModule 而不是创建新的,如下所示:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

请注意,您必须调用 GraphModule.recompile() 来使 GraphModule 上生成的 forward() 方法与修改后的 Graph 同步。

鉴于您已经传入一个已追踪成 Graphtorch.nn.Module,现在有两种主要方法可以用来构建一个新的 Graph

快速图入门#

图语义的完整介绍可以在 Graph 文档中找到,但我们在这里将介绍基础知识。 Graph 是一个表示 GraphModule 中方法的_数据结构。它所需的信息是:

  • 方法有哪些输入?

  • 方法内部运行的操作是什么?

  • 方法的输出(即返回值)是什么?

所有这三个概念都用 Node 实例表示。让我们用一个简单的例子来看看这是什么意思:


import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

这里我们定义了一个模块 MyModule 用于演示,实例化它,符号追踪它,然后调用 Graph.print_tabular() 方法来打印一个表格,显示这个 Graph 的节点。

操作码

名称

目标

参数

关键字参数

占位符

x

x

()

{}

get_attr

linear_weight

linear.weight

()

{}

call_function

add_1

(x, linear_weight)

{}

call_module

linear_1

linear

(add_1,)

{}

call_method

relu_1

relu

(linear_1,)

{}

call_function

sum_1

<内置方法 sum …>

(relu_1,)

{‘dim’: -1}

call_function

topk_1

<内置方法 topk …>

(sum_1, 3)

{}

output

output

output

(topk_1,)

{}

我们可以利用这些信息来回答上面提出的问题。

  • 方法有哪些输入?在 FX 中,方法输入通过特殊的 placeholder 节点指定。在这种情况下,我们有一个带有 targetxplaceholder 节点,这意味着我们有一个名为 x 的单个(非 self)参数。

  • 方法内部的操作是什么?get_attrcall_functioncall_modulecall_method 节点代表方法中的操作。所有这些操作的完整语义可以在 Node 文档中找到。

  • 方法的返回值是什么?Graph 中的返回值由特殊的 output 节点指定。

鉴于我们现在了解了 FX 中代码表示的基本知识,我们现在可以探索如何编辑 Graph

图操作#

直接图操作#

构建新 Graph 的一种方法是直接操作旧的图。为此,我们可以简单地获取从符号追踪获得的 Graph 并对其进行修改。例如,假设我们希望将 torch.add() 调用替换为 torch.mul() 调用。


import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
                tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                    # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以进行更复杂的 Graph 重写,例如删除或追加节点。为了辅助这些转换,FX 在 Graph 文档中提供了用于转换图的实用函数。下面有一个使用这些 API 追加 torch.relu() 调用的示例。


# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单转换,您还可以使用 子图重写器

使用 replace_pattern() 进行子图重写#

FX 还提供了比直接图操作更高级别的自动化。 replace_pattern() API 基本上是一个用于编辑 Graph 的“查找/替换”工具。它允许您指定一个 pattern 和一个 replacement 函数,它将追踪这些函数,在 pattern 图中找到操作组的实例,然后用 replacement 图的副本替换这些实例。这有助于极大地自动化繁琐的图操作代码,因为随着转换变得更复杂,代码可能会变得笨拙。

代理/重追踪#

操作 Graph 的另一种方法是重用符号追踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是执行所需图重写,在 F.relu 之后插入比较和乘法,然后清理原始 F.relu。但是,我们可以通过使用 Proxy 对象自动将操作记录到 Graph 中来自动化此过程。

使用此方法,我们将要插入的操作编写为常规 PyTorch 代码,并使用 Proxy 对象作为参数来调用该代码。这些 Proxy 对象将捕获对它们的执行的操作,并将它们追加到 Graph


# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
                tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用 Proxy 还允许您将重写规则指定为本机 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy 时,我们还传递了一个指向底层变量 graph 的追踪器。这是为了以防万一操作是 n 元的(例如,add 是二元运算符),对 Proxy 的调用不会创建多个图追踪器实例,这可能导致意外的运行时错误。我们尤其推荐这种使用 Proxy 的方法,因为底层运算符不能安全地假定为一元运算符。

使用 Proxy 进行 Graph 操作的完整示例可以在 这里 找到。

解释器模式#

FX 中一个有用的代码组织模式是遍历 Graph 中的所有 Node 并执行它们。这可用于多种用途,包括对流经图的值进行运行时分析,或通过使用 Proxy 进行重追踪来转换代码。例如,假设我们想运行一个 GraphModule,并在运行时看到 torch.Tensor 的形状和 dtype 属性时记录它们。这可能看起来像:


import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':

                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如您所见,FX 的完整解释器并不复杂,但可能非常有用。为了方便使用这种模式,我们提供了 Interpreter 类,它包含了上述逻辑,并通过方法覆盖来覆盖解释器执行的某些方面。

除了执行操作之外,我们还可以通过将 Proxy 值传递给解释器来生成新的 Graph。类似地,我们提供了 Transformer 类来包含此模式。 Transformer 的行为类似于 Interpreter,但不是调用 run 方法来获取模块的具体输出值,而是调用 Transformer.transform() 方法来返回一个新 GraphModule,该 GraphModule 应用了您作为覆盖方法安装的任何转换规则。

解释器模式示例#

调试#

引言#

在编写转换的过程中,我们的代码可能并不总是正确的。在这种情况下,我们可能需要进行一些调试。关键是反向工作:首先,检查调用生成模块的结果以证明或证伪正确性。然后,检查并调试生成的代码。然后,调试导致生成代码的转换过程。

如果您不熟悉调试器,请参阅辅助部分 可用调试器

转换创作中的常见陷阱#

  • 不确定的 set 迭代顺序。在 Python 中,set 数据类型是无序的。使用 set 来包含诸如 Node 等对象的集合,可能会导致意外的非确定性。一个例子是迭代一组 Node 以将它们插入 Graph。由于 set 数据类型是无序的,输出程序中操作的顺序将是非确定性的,并且可能在程序调用之间发生变化。推荐的替代方法是使用 dict 数据类型,它从 Python 3.7 开始(以及从 cpython 3.6 开始)是按插入顺序排序的。通过存储在 dict 的键中的值,可以等效地使用 dict 来实现集合去重。

检查模块的正确性#

由于大多数深度学习模块的输出是浮点 torch.Tensor 实例,因此检查两个 torch.nn.Module 的结果是否等效不像执行简单的相等检查那样直接。为了说明这一点,让我们举一个例子:


import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用 == 相等运算符检查两个深度学习模型的输出值是否相等。然而,这既不明
确,因为该运算符返回的是张量而不是布尔值,而且由于浮点值的比较应该使用误差范围(或 epsilon)来考虑浮点运算的非交换性(更多详细信息请参阅此处)。我们可以改用 torch.allclose(),它将提供一个近似比较,同时考虑相对和绝对容差阈值。

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中用于检查转换后的模块是否按预期行为与参考实现相比的第一个工具。

调试生成的代码#

由于 FX 在 GraphModule 上生成 forward() 函数,因此使用传统的调试技术,如 print 语句或 pdb,并不那么直接。幸运的是,我们有几种技术可用于调试生成的代码。

使用 pdb#

调用 pdb 进入正在运行的程序。虽然表示 Graph 的代码不在任何源文件中,但在调用前向传播时,我们仍然可以通过 pdb 手动进入它。


import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用 GraphModule 中的 to_folder 函数#

GraphModule.to_folder()GraphModule 中的一个方法,它允许您将生成的 FX 代码转储到一个文件夹。虽然像 打印生成的代码 中那样将前向传递复制到代码中通常就足够了,但使用 to_folder 检查模块和参数可能更容易。


m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看 foo/module.py 中的代码,并根据需要进行修改(例如,添加 print 语句或使用 pdb)来调试生成的代码。

调试转换#

现在我们已经确定某个转换产生了不正确的代码,是时候调试转换本身了。首先,我们将检查文档中的 符号追踪的局限性 部分。一旦我们确认追踪按预期工作,目标就是弄清楚在 GraphModule 转换过程中出了什么问题。 编写转换 中可能有一个快速答案,但如果没有,有几种方法可以检查我们的追踪模块。


# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上述实用函数,我们可以比较我们追踪的模块在应用转换之前和之后的状态。有时,简单的目视比较足以追溯错误。如果仍然不清楚问题所在,调试器如 pdb 是一个不错的选择。

基于上面的例子,考虑以下代码:


# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的例子,假设 print(traced) 的调用显示我们的转换存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 上设置断点,然后按 s “步入” transform_graph(traced) 的调用来查看转换过程中发生的情况。

我们也可以尝试编辑 print_tabular 方法来打印节点在图中的不同属性。(例如,我们可能想查看节点的 input_nodesusers。)

可用调试器#

最常见的 Python 调试器是 pdb。您可以通过在命令行中键入 python -m pdb FILENAME.py 来以“调试模式”启动程序,其中 FILENAME 是您要调试的文件名。之后,您可以使用 pdb 调试命令逐步执行正在运行的程序。通常,在启动 pdb 时设置一个断点(b LINE-NUMBER),然后调用 c 来运行程序直到该点。这可以避免您必须逐行执行(使用 sn)才能到达要检查的代码部分。或者,您可以将 import pdb; pdb.set_trace() 放在您想中断的行之前。如果您添加了 pdb.set_trace(),程序在运行时会自动进入调试模式。(换句话说,您只需在命令行中键入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦您的文件在调试模式下运行,您就可以使用某些命令逐步执行代码并检查程序的内部状态。在线有许多关于 pdb 的优秀教程,包括 RealPython 的 “Python 调试与 Pdb”

PyCharm 或 VSCode 等 IDE 通常内置有调试器。在您的 IDE 中,您可以选择 a) 通过在 IDE 中打开终端窗口(例如,VSCode 中的 View → Terminal)来使用 pdb,或者 b) 使用内置调试器(通常是 pdb 的图形包装器)。

符号追踪的局限性#

FX 使用一种称为符号追踪(也称为 符号执行)的系统,以可转换/可分析的形式捕获程序的语义。该系统是追踪的,因为它执行程序(实际上是一个 torch.nn.Module 或函数)来记录操作。它是符号的,因为在此执行期间流经程序的数据不是真实数据,而是符号(在 FX 术语中称为 Proxy)。

尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。

动态控制流#

符号追踪的主要限制是它目前不支持动态控制流。也就是说,循环或 if 语句,其条件可能取决于程序的输入值。

例如,让我们检查以下程序:


def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:

        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
    <...>
    File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
    File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
    File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if 语句的条件依赖于 x.sum() 的值,而 x.sum() 又依赖于函数输入 x 的值。由于 x 可以改变(即,如果您将新的输入张量传递给追踪的函数),这就是动态控制流。回溯会沿着您的代码向上追溯,以显示这种情况发生的位置。

静态控制流#

另一方面,所谓的静态控制流是受支持的。静态控制流是循环或 if 语句,其值在调用之间不会改变。通常,在 PyTorch 程序中,这种控制流出现在根据超参数对模型的架构做出决策的代码中。作为一个具体的例子:


import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if 语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。 do_activation 可以被认为是超参数,并且具有不同该参数值的 MyModule 实例的不同追踪会产生不同的代码。这是一个有效的模式,并且受到符号追踪的支持。

许多动态控制流的实例实际上是静态控制流。这些实例可以通过消除对输入值的依赖关系来支持符号追踪,例如,通过将值移到 Module 属性中,或在符号追踪期间将具体值绑定到参数。


def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

对于真正的动态控制流,包含这些代码的程序部分可以被追踪为对 Method(参见 使用 Tracer 类定制追踪)或函数(参见 wrap())的调用,而不是追踪其内部。

torch 函数#

FX 使用 __torch_function__ 作为它拦截调用的机制(有关更多信息,请参阅 技术概述)。某些函数,例如内置 Python 函数或 math 模块中的函数,不受 __torch_function__ 的覆盖,但我们仍然希望在符号追踪中捕获它们。例如:


import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
    <...>
    File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
    File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误告诉我们内置函数 len 不受支持。我们可以使用 wrap() API 使诸如 len 之类的函数在追踪中被记录为直接调用。


torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用 Tracer 类定制追踪#

Tracer 类是 symbolic_trace 实现的基础。通过继承 Tracer 可以自定义追踪行为,如下所示:


class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶子模块#

叶子模块是出现在符号追踪中的模块,而不是被追踪穿透的模块。默认的叶子模块集是标准的 torch.nn 模块实例集。例如:


class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

通过覆盖 Tracer.is_leaf_module() 可以自定义叶子模块集。

杂项#

  • 张量构造函数(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor)目前不可追踪。

    • 确定性构造函数(zerosones)可以使用,并且它们生成的值将作为常量嵌入到追踪中。如果这些构造函数的参数引用了动态输入大小,这才是有问题的。在这种情况下,ones_likezeros_like 可能是可行的替代方案。

    • 非确定性构造函数(randrandn)将嵌入一个随机值到追踪中。这很可能不是预期的行为。一种变通方法是将 torch.randn 包装在 torch.fx.wrap 函数中,然后调用它。

    
    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能在将来的版本中修复。

  • 类型注解

    • Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor)是受支持的,并且将由符号追踪保留。

    • Python 2 风格的注释类型注解 # type: (torch.Tensor, int) -> torch.Tensor 目前不受支持。

    • 函数内局部名称的注解目前不受支持。

  • 关于 training 标志和子模块的注意事项

    • 在使用 torch.nn.functional.dropout 等函数时,通常会将 training 参数作为 self.training 传入。在 FX 追踪期间,这很可能会作为常量值被烘焙进去。

    
    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
        def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
        dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
        return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 但是,当使用标准的 nn.Dropout() 子模块时,training 标志是封装的,并且—由于 nn.Module 对象模型的保留—可以更改。

    
    class DropoutRepro2(torch.nn.Module):
        def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
        def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
        drop = self.drop(x);  x = None
        return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 因此,请考虑将与 training 标志动态交互的模块标记为叶子模块。

API 参考#

torch.fx.symbolic_trace(root, concrete_args=None)[source]#

符号追踪 API

给定一个 nn.Module 或函数实例 root,此函数将返回一个 GraphModule,该 GraphModule 通过记录追踪 root 时看到的_操作来构建。

concrete_args 允许您部分特化函数,无论是为了消除控制流还是数据结构。

例如

def f(a, b):
    if b == True:
        return a
    else:
        return a * 2

FX 通常无法追踪此内容,因为其中存在控制流。但是,我们可以使用concrete_args 来特化b 的值以追踪此内容。

f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

请注意,尽管您仍然可以传入 b 的不同值,但它们将被忽略。

我们还可以使用concrete_args 来消除函数中的数据结构处理。这将使用 pytrees 来展平您的输入。为避免过度特化,请为不应特化的值传入 fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


f = fx.symbolic_trace(
    f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
)
assert f({"a": 1, "b": 2, "c": 4}) == 7
参数
  • root (Union[torch.nn.Module, Callable]) – 要追踪并转换为 Graph 表示的模块或函数。

  • concrete_args (Optional[Dict[str, any]]) – 要部分特化的输入

返回

一个由 root 记录的操作构建的模块。

返回类型

GraphModule

注意

此 API 的向后兼容性已得到保证。

torch.fx.wrap(fn_or_name)[source]#

此函数可以在模块级别作用域调用,以将 fn_or_name 注册为“叶子函数”。“叶子函数”将在 FX 追踪中被保留为 CallFunction 节点,而不是被追踪穿透。

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y


torch.fx.wrap("my_custom_function")


def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以用作装饰器

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

包装后的函数可以被认为是“叶子函数”,类似于“叶子模块”的概念,即它们是在 FX 追踪中保留为调用的函数,而不是被追踪穿透。

参数

fn_or_name (Union[str, Callable]) – 当调用函数或全局函数名时要插入图中的函数或全局函数名。

注意

此 API 的向后兼容性已得到保证。

class torch.fx.GraphModule(*args, **kwargs)[source]#

GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告

graph 被重新赋值时,codeforward 将被自动重新生成。但是,如果您在未重新赋值 graph 属性本身的情况下编辑 graph 的内容,则必须调用 recompile() 来更新生成的代码。

注意

此 API 的向后兼容性已得到保证。

__init__(root, graph, class_name='GraphModule')[source]#

构造一个 GraphModule。

参数
  • root (Union[torch.nn.Module, Dict[str, Any]) – root 可以是 nn.Module 实例,也可以是映射字符串到任何属性类型的 Dict。如果 root 是 Module,那么 Graph 的 Nodes 的 target 字段中对 Module 相关对象的引用(通过限定名)将从 root 的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果 root 是 dict,那么在 Node 的 target 中找到的限定名将直接在 dict 的键中查找。由 Dict 映射的对象将被复制到 GraphModule 的模块层次结构中的相应位置。

  • graph (Graph) – graph 包含此 GraphModule 应使用的节点来生成代码。

  • class_name (str) – name 表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息都将报告为源自 GraphModule。将其设置为 root 的原始名称或在您的转换上下文中具有意义的名称可能很有用。

注意

此 API 的向后兼容性已得到保证。

add_submodule(target, m)[source]#

将给定的子模块添加到 self

如果 target 的子路径尚不存在,则会在此安装空模块。

参数
  • target (str) – 新子模块的完全限定名称(请参阅 nn.Module.get_submodule 中的示例,了解如何指定完全限定名称)。

  • m (Module) – 子模块本身;我们要安装到当前模块中的实际对象。

返回

子模块是否可以被插入。为了

使此方法返回 True,由 target 表示的链中的每个对象必须 a) 尚不存在,或 b) 引用一个 nn.Module(而不是参数或其他属性)。

返回类型

布尔值

注意

此 API 的向后兼容性已得到保证。

property code: str#

返回从此 GraphModule 底层的 Graph 生成的 Python 代码。

delete_all_unused_submodules()[source]#

self 中删除所有未使用的子模块。

当满足以下任一条件时,模块被认为是“已使用”:1. 它有已使用的子模块 2. 它的 forward 被直接通过 call_module 节点调用 3. 它有一个非 Module 属性,该属性从 get_attr 节点使用。

此方法可用于清理 nn.Module,而无需手动对每个未使用的子模块调用 delete_submodule

注意

此 API 的向后兼容性已得到保证。

delete_submodule(target)[source]#

self 中删除给定的子模块。

如果 target 不是有效目标,则不会删除该模块。

参数

target (str) – 新子模块的完全限定名称(请参阅 nn.Module.get_submodule 中的示例,了解如何指定完全限定名称)。

返回

目标字符串是否引用了

我们要删除的子模块。返回值为 False 意味着 target 不是对子模块的有效引用。

返回类型

布尔值

注意

此 API 的向后兼容性已得到保证。

property graph: Graph#

返回此 GraphModule 底层的 Graph

print_readable(print_output=True, include_stride=False, include_device=False, colored=False, *, fast_sympy_print=False, expanded_def=False)[source]#

返回当前 GraphModule 及其子 GraphModule 生成的 Python 代码。

警告

此 API 是实验性的,并且向后兼容。

recompile()[source]#

从其 graph 属性重新编译此 GraphModule。在编辑了包含的 graph 之后应调用此方法,否则此 GraphModule 的生成代码将过时。

注意

此 API 的向后兼容性已得到保证。

返回类型

PythonCode

to_folder(folder, module_name='FxModule')[source]#
将模块转储到 folder 中,使用 module_name,以便可以

使用 from <folder> import <module_name> 导入。

参数 (Args)

folder (Union[str, os.PathLike]): 要将代码写入的文件夹。

module_name (str): 在

写出代码时用于 Module 的顶级名称。

警告

此 API 是实验性的,并且向后兼容。

class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#

Graph 是 FX 中间表示使用的主要数据结构。它由一系列 Node 组成,每个 Node 代表一个调用点(或其他语法结构)。Node 的列表共同构成一个有效的 Python 函数。

例如,以下代码:

import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(
            torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
        )


m = MyModule()
gm = torch.fx.symbolic_trace(m)

将产生以下图:

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关 Graph 中操作语义的信息,请参阅 Node

注意

此 API 的向后兼容性已得到保证。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#

构造一个空的 Graph。

注意

此 API 的向后兼容性已得到保证。

call_function(the_function, args=None, kwargs=None, type_expr=None, name=None)[source]#

call_function Node 插入到 Graph 中。call_function 节点表示对由 the_function 指定的 Python 可调用对象的调用。

参数
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或 builtinsoperator 命名空间中的成员。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用函数的 positional 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用函数的 keyword 参数。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

  • name (Optional[str]) – 节点的名称。如果未指定,则设置为 None。

返回

新创建并插入的 call_function 节点。

返回类型

Node

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性已得到保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[source]#

call_method Node 插入到 Graph 中。call_method 节点表示对 args 的第 0 个元素上的给定方法进行的调用。

参数
  • method_name (str) – 要应用于 self 参数的方法名称。例如,如果 args[0] 是表示 TensorNode,那么要在该 Tensor 上调用 relu(),请将 relu 传递给 method_name

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的 positional 参数。请注意,这应该包含一个 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的 keyword 参数。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

返回

新创建并插入的 call_method 节点。

返回类型

Node

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性已得到保证。

call_module(module_name, args=None, kwargs=None, type_expr=None)[source]#

将一个call_module Node 插入到Graph 中。call_module 节点表示在Module 层次结构中调用Module 的 forward() 函数。

参数
  • module_name (str) – 要调用的ModuleModule 层次结构中的限定名称。例如,如果跟踪的Module 包含一个名为foo 的子模块,该子模块又包含一个名为bar 的子模块,则应将限定名称foo.bar 作为module_name 传递,以调用该模块。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的 positional 参数。请注意,这不应包含self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的 keyword 参数。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

返回

新创建并插入的call_module 节点。

返回类型

Node

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性已得到保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]#

创建一个Node 并将其插入到当前插入点的Graph 中。请注意,当前插入点可以通过Graph.inserting_before()Graph.inserting_after() 来设置。

参数
  • op (str) – 此 Node 的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。这些操作码的语义在Graph 文档字符串中有所描述。

  • args (Optional[Tuple[Argument, ...]]) – 是此节点的参数元组。

  • kwargs (Optional[Dict[str, Argument]]) – 此 Node 的 kwargs。

  • name (Optional[str]) – Node 的可选字符串名称。这会影响在生成的 Python 代码中分配给该值的名称。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

返回

新创建并插入的节点。

返回类型

Node

注意

此 API 的向后兼容性已得到保证。

eliminate_dead_code(is_impure_node=None)[source]#

从图中删除所有死代码,基于每个节点的 use 数量以及节点是否具有任何 side effects。调用前必须对图进行拓扑排序。

参数
  • is_impure_node (Optional[Callable[[Node], bool]]) – 一个返回

  • None (节点是否为 impure。如果是) –

  • to (则默认行为是) –

  • Node.is_impure. (使用) –

返回

图是否因该过程而改变。

返回类型

布尔值

示例

在消除死代码之前,下面 a = x + 1 中的 a 没有 use,因此可以从图中消除而不产生任何影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

消除死代码后,a = x + 1 已被删除,其余的 forward 保持不变。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法来避免删除具有 side effects 的节点 (参见 Node.is_impure),但总的来说覆盖率非常差,因此您应该假定调用此方法不是安全的,除非您知道您的 FX 图完全由函数式操作组成,或者您提供了自己的自定义函数来检测具有 side effects 的节点。

注意

此 API 的向后兼容性已得到保证。

erase_node(to_erase)[source]#

Graph 中删除一个Node。如果该节点在Graph 中仍有 use,则会引发异常。

参数

to_erase (Node) – 要从Graph 中删除的Node

注意

此 API 的向后兼容性已得到保证。

find_nodes(*, op, target=None, sort=True)[source]#

允许快速查询节点。

参数
  • op (str) – 操作的名称。

  • target (Optional[Target]) – 节点的 target。对于 call_function,target 是必需的。对于其他 op,target 是可选的。

  • sort (bool) – 是否按节点在图上出现的顺序返回节点。

返回

具有请求的操作和 target 的节点的可迭代对象。

警告

此 API 是实验性的,并且向后兼容。

get_attr(qualified_name, type_expr=None)[source]#

将一个get_attr 节点插入到 Graph 中。get_attr Node 表示从Module 层次结构中获取一个属性。

参数
  • qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的 Module 包含一个名为foo 的子模块,该子模块包含一个名为bar 的子模块,该子模块有一个名为baz 的属性,则应将限定名称foo.bar.baz 作为qualified_name 传递。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

返回

新创建并插入的get_attr 节点。

返回类型

Node

注意

此方法与Graph.create_node 具有相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性已得到保证。

graph_copy(g, val_map, return_output_node=False)[source]#

将给定图中的所有节点复制到self 中。

参数
  • g (Graph) – 要从中复制 Nodes 的源图。

  • val_map (Dict[Node, Node]) – 一个字典,将填充从g 中的节点到self 中的节点的映射。请注意,可以传入带有值的val_map 来覆盖某些值的复制。

返回

self 中与g 中的输出值等效的值,如果g 有一个output 节点。否则为None

返回类型

Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性已得到保证。

inserting_after(n=None)[source]#
设置 create_node 和伴随方法将插入到图中的点。

在 `with` 语句中使用时,这将临时设置插入点,然后在 `with` 语句退出时将其恢复。

with g.inserting_after(n):
    ...  # inserting after node n
...  # insert point restored to what it was previously
g.inserting_after(n)  #  set the insert point permanently

参数 (Args)

n (Optional[Node]): 在其之前插入的节点。如果为 None,则将在图的开头之后插入。

整个图的开头。

返回

一个资源管理器,它将在 `__exit__` 时恢复插入点。

注意

此 API 的向后兼容性已得到保证。

inserting_before(n=None)[source]#
设置 create_node 和伴随方法将插入到图中的点。

在 `with` 语句中使用时,这将临时设置插入点,然后在 `with` 语句退出时将其恢复。

with g.inserting_before(n):
    ...  # inserting before node n
...  # insert point restored to what it was previously
g.inserting_before(n)  #  set the insert point permanently

参数 (Args)

n (Optional[Node]): 在其之前插入的节点。如果为 None,则将在图的开头之前插入。

整个图的开头。

返回

一个资源管理器,它将在 `__exit__` 时恢复插入点。

注意

此 API 的向后兼容性已得到保证。

lint()[source]#

运行此 Graph 的各种检查,以确保其结构良好。特别是: - 检查 Nodes 是否具有正确的归属(属于此图) - 检查 Nodes 是否按拓扑顺序出现 - 如果此 Graph 具有拥有它的 GraphModule,则检查 target 是否在该 GraphModule 中存在。

注意

此 API 的向后兼容性已得到保证。

node_copy(node, arg_transform=<function Graph.<lambda>>)[source]#

将一个节点从一个图复制到另一个图。`arg_transform` 需要将参数从 node 的图转换为 self 的图。示例

# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
参数
  • node (Node) – 要复制到self 中的节点。

  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将 node 的 `args` 和 `kwargs` 中的 `Node` 参数转换为 `self` 中等效的参数。在最简单的情况下,这应该从一个将原始图中的 Nodes 映射到 `self` 的表中检索一个值。

返回类型

Node

注意

此 API 的向后兼容性已得到保证。

property nodes: list['Node']#

获取构成此 Graph 的 Nodes 列表。

注意,这个 `Node` 列表表示是一个双向链表。在迭代过程中进行修改(例如,删除一个 Node,添加一个 Node)是安全的。

返回

Nodes 的双向链表。注意,可以对此列表调用 `reversed` 来切换迭代顺序。

on_generate_code(make_transformer)[source]#

在生成 Python 代码时注册一个 transformer 函数。

参数 (Args)
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])

一个返回代码 transformer 的函数,用于注册。此函数由 `on_generate_code` 调用以获取代码 transformer。

此函数也作为输入提供当前注册的代码 transformer(如果未注册,则为 None),以防不希望覆盖它。这对于链接代码 transformer 很有用。

返回

一个上下文管理器,当在 `with` 语句中使用时,会自动恢复先前注册的代码 transformer。

示例

gm: fx.GraphModule = ...


# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]


# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此函数还可以用作上下文管理器,其优点是可以自动恢复先前注册的代码 transformer。

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 是实验性的,并且向后兼容。

output(result, type_expr=None)[source]#

将一个 `output` Node 插入到Graph 中。`output` 节点表示 Python 代码中的 `return` 语句。`result` 是应返回的值。

参数
  • result (Argument) – 要返回的值。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。

注意

此方法与Graph.create_node 具有相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性已得到保证。

output_node()[source]#

警告

此 API 是实验性的,并且向后兼容。

返回类型

Node

placeholder(name, type_expr=None, default_value)[source]#

将一个 `placeholder` 节点插入到 Graph 中。`placeholder` 表示函数输入。

参数
  • name (str) – 输入值的名称。这对应于此 `Graph` 所代表的函数的 positional 参数的名称。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。在某些情况下,这对于正确的代码生成是必需的(例如,当函数随后在 TorchScript 编译中使用时)。

  • default_value (Any) – 此函数参数应采用的默认值。注意:为了允许将 `None` 作为默认值,应将 `inspect.Signature.empty` 作为此参数传递,以指定该参数*没有*默认值。

返回类型

Node

注意

此方法与Graph.create_node 具有相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性已得到保证。

print_tabular()[source]#

以表格格式打印图的中间表示。请注意,此 API 需要安装 `tabulate` 模块。

注意

此 API 的向后兼容性已得到保证。

process_inputs(*args)[source]#

处理 args,以便可以将其传递给 FX 图。

警告

此 API 是实验性的,并且向后兼容。

process_outputs(out)[source]#

警告

此 API 是实验性的,并且向后兼容。

python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False, expanded_def=False)[source]#

将此 `Graph` 转换为有效的 Python 代码。

参数

root_module (str) – 用于查找限定名称 target 的根模块的名称。通常是 ‘self’。

返回

src:表示对象的 Python 源代码 globals:`src` 中的全局名称字典 -> 它们引用的对象。

返回类型

一个 PythonCode 对象,包含两个字段。

注意

此 API 的向后兼容性已得到保证。

set_codegen(codegen)[source]#

警告

此 API 是实验性的,并且向后兼容。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]#

`Node` 是表示 `Graph` 中单个操作的数据结构。在大多数情况下,Nodes 表示对各种实体的调用,例如运算符、方法和 Modules(一些例外包括指定函数输入和输出的节点)。每个 `Node` 都有一个由其 `op` 属性指定的函数。`op` 每个值的 `Node` 语义如下:

  • `placeholder` 表示函数输入。`name` 属性指定该值将采用的名称。`target` 同样是参数的名称。`args` 持有:1) 空,或 2) 一个表示函数输入的默认参数的单个参数。`kwargs` 是不关心的。

  • `get_attr` 从模块层次结构中检索参数。`name` 同样是分配给获取结果的名称。`target` 是参数在模块层次结构中的位置的完全限定名称。`args` 和 `kwargs` 是不关心的。

  • `call_function` 将一个自由函数应用于某些值。`name` 同样是分配给该值的名称。`target` 是要应用的函数。`args` 和 `kwargs` 表示函数的参数,遵循 Python 调用约定。

  • `call_module` 将模块层次结构中的模块的 `forward()` 方法应用于给定参数。`name` 同上。`target` 是要调用的模块在模块层次结构中的完全限定名称。`args` 和 `kwargs` 表示调用模块的参数,*不包括 self 参数*。

  • `call_method` 调用一个值上的方法。`name` 同上。`target` 是应用于 `self` 参数的方法的字符串名称。`args` 和 `kwargs` 表示调用模块的参数,*包括 self 参数*。

  • `output` 在其 `args[0]` 属性中包含跟踪函数的输出。这对应于 Graph 打印输出中的“return”语句。

注意

此 API 的向后兼容性已得到保证。

property all_input_nodes: list['Node']#

返回作为此 Node 输入的所有 Nodes。这相当于迭代 `args` 和 `kwargs` 并仅收集是 Nodes 的值。

返回

出现在此 `Node` 的 `args` 和 `kwargs` 中的 `Nodes` 列表,按该顺序。

append(x)[source]#

在图节点列表中在此节点之后插入 `x`。等同于 `self.next.prepend(x)`。

参数

x (Node) – 要放在此节点之后的节点。必须是同一图的成员。

注意

此 API 的向后兼容性已得到保证。

property args: tuple[Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...]#

此 `Node` 的参数元组。参数的解释取决于节点的 op。有关更多信息,请参阅 Node 文档字符串。

允许对此属性进行赋值。所有 use 和 users 的计数都会在赋值时自动更新。

format_node(placeholder_names=None, maybe_return_typename=None, *, include_tensor_metadata=False)[source]#

返回 `self` 的描述性字符串表示。

此方法可作为调试实用程序使用,无需参数。

此函数还在 `Graph` 的 `__str__` 方法中用作内部助手。`placeholder_names` 和 `maybe_return_typename` 中的字符串共同构成了此 Graph 外部 GraphModule 中自动生成的 `forward` 函数的签名。`placeholder_names` 和 `maybe_return_typename` 不得在其他地方使用。

参数
  • placeholder_names (Optional[list[str]]) – 一个列表,将存储表示生成的 `forward` 函数中 placeholder 的格式化字符串。仅内部使用。

  • maybe_return_typename (Optional[list[str]]) – 一个单元素列表,将存储表示生成的 `forward` 函数输出的格式化字符串。仅内部使用。

  • include_tensor_metadata (bool) – 是否包含张量元数据。

返回

如果 1) 我们正在使用 `format_node` 作为内部助手

在 `Graph` 的 `__str__` 方法中,并且 2) `self` 是一个 placeholder Node,则返回 `None`。否则,返回当前 Node 的描述性字符串表示。

返回类型

str

注意

此 API 的向后兼容性已得到保证。

insert_arg(idx, arg)[source]#

将一个 positional 参数插入到具有给定索引的参数列表中。

参数
  • idx (int) – 要插入之前的 `self.args` 中的元素的索引。

  • arg (Argument) – 要插入 `args` 中的新参数值。

注意

此 API 的向后兼容性已得到保证。

is_impure(impure_random=True)[source]#

返回此 op 是否为 impure,即如果其 op 是 placeholder 或 output,或者是一个 impure 的 call_function 或 call_module。

参数

impure_random (bool) – 是否将 rand op 视为 impure。

返回

op 是 impure 还是不是。

返回类型

布尔值

警告

此 API 是实验性的,并且向后兼容。

property kwargs: dict[str, Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]#

此 `Node` 的关键字参数字典。参数的解释取决于节点的 op。有关更多信息,请参阅 Node 文档字符串。

允许对此属性进行赋值。所有 use 和 users 的计数都会在赋值时自动更新。

property next: Node#

返回链表中下一个 `Node`。

返回

链表中下一个 `Node`。

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]#

返回 Python target 的标准化参数。这意味着 `args/kwargs` 将与 module/functional 的签名匹配,并且如果 `normalize_to_only_use_kwargs` 为 true,则将仅按位置返回 kwargs。还将填充默认值。不支持仅 positional 参数或 varargs 参数。

支持模块调用。

可能需要 `arg_types` 和 `kwarg_types` 来区分重载。

参数
  • root (torch.nn.Module) – 用于解析模块 target 的模块。

  • arg_types (Optional[Tuple[Any]]) – arg 的 arg 类型元组。

  • kwarg_types (Optional[Dict[str, Any]]) – kwargs 的 arg 类型字典。

  • normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用 kwargs。

返回

返回 NamedTuple ArgsKwargsPair,如果未成功则返回 `None`。

返回类型

Optional[ArgsKwargsPair]

警告

此 API 是实验性的,并且向后兼容。

prepend(x)[source]#

在图节点列表中在此节点之前插入 x。示例

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数

x (Node) – 要放在此节点之前的节点。必须是同一图的成员。

注意

此 API 的向后兼容性已得到保证。

property prev: Node#

返回链表中上一个 `Node`。

返回

链表中上一个 `Node`。

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source]#

将 Graph 中 `self` 的所有 use 替换为 Node `replace_with`。

参数
  • replace_with (Node) – 用于替换 `self` 的所有 use 的节点。

  • delete_user_cb (Callable) – 用于确定是否应删除 self 节点的给定 use 的回调函数。

  • propagate_meta (bool) – 是否将原始节点的 .meta 字段中的所有属性复制到替换节点。为安全起见,仅当替换节点尚不具有 .meta 字段时才有效。

返回

在此更改被执行的 Nodes 列表。

返回类型

list[‘Node’]

注意

此 API 的向后兼容性已得到保证。

replace_input_with(old_input, new_input)[source]#

遍历 `self` 的输入节点,并将 `old_input` 的所有实例替换为 `new_input`。

参数
  • old_input (Node) – 要被替换的旧输入节点。

  • new_input (Node) – 用于替换 `old_input` 的新输入节点。

注意

此 API 的向后兼容性已得到保证。

property stack_trace: Optional[str]#

返回跟踪期间记录的 Python 堆栈跟踪(如果有)。当使用 fx.Tracer 进行跟踪时,此属性通常由 `Tracer.create_proxy` 填充。要在跟踪期间记录堆栈跟踪以进行调试,请在 `Tracer` 实例上设置 `record_stack_traces = True`。当使用 dynamo 进行跟踪时,此属性将由 `OutputGraph.create_proxy` 默认填充。

stack_trace 的字符串末尾将是最内层的帧。

update_arg(idx, arg)[source]#

使用新值 `arg` 更新现有 positional 参数。调用后,`self.args[idx] == arg`。

参数
  • idx (int) – 要更新的元素在 `self.args` 中的索引。

  • arg (Argument) – 要写入 `args` 中的新参数值。

注意

此 API 的向后兼容性已得到保证。

update_kwarg(key, arg)[source]#

使用新值 `arg` 更新现有关键字参数。调用后,`self.kwargs[key] == arg`。

参数
  • key (str) – 要更新的元素在 `self.kwargs` 中的键。

  • arg (Argument) – 要写入 `kwargs` 中的新参数值。

注意

此 API 的向后兼容性已得到保证。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]#

`Tracer` 是实现 `torch.fx.symbolic_trace` 的符号跟踪功能的类。调用 `symbolic_trace(m)` 等同于 `Tracer().trace(m)`。

Tracer 可以被子类化以覆盖跟踪过程的各种行为。可以覆盖的不同行为在此类方法的文档字符串中进行了描述。

注意

此 API 的向后兼容性已得到保证。

call_module(m, forward, args, kwargs)[source]#

指定此 `Tracer` 在遇到对 `nn.Module` 实例的调用时行为的方法。

默认情况下,行为是检查被调用的模块是否为叶子模块(通过 `is_leaf_module`)。如果是,则在 `Graph` 中发出一个指向 `m` 的 `call_module` 节点。否则,正常调用 `Module`,跟踪其 `forward` 函数中的操作。

可以覆盖此方法来实现,例如,创建嵌套的跟踪 GraphModules,或在跟踪跨越 `Module` 边界时执行任何其他所需行为。

参数
  • m (Module) – 正在发出调用的模块。

  • forward (Callable) – 要调用的 `Module` 的 forward() 方法。

  • args (Tuple) – 模块调用处的 args。

  • kwargs (Dict) – 模块调用处的 kwargs。

返回

来自 Module 调用的返回值。在发出 `call_module` 节点的情况下,这是一个 `Proxy` 值。否则,它是从 `Module` 调用返回的任何值。

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

create_arg(a)[source]#

用于指定此 `Tracer` 在准备作为 `Graph` 中节点参数的值时的行为的方法。

默认行为包括:

  1. 迭代集合类型(例如,tuple、list、dict),并递归调用 `create_args` 来处理元素。

  2. 给定 Proxy 对象,返回对底层 IR `Node` 的引用。

  3. 给定非 Proxy Tensor 对象,为各种情况发出 IR。

    • 对于 Parameter,发出指向该 Parameter 的 `get_attr` 节点。

    • 对于非 Parameter Tensor,将该 Tensor 存储在一个特殊的属性中,该属性指向该属性。

可以覆盖此方法以支持更多类型。

参数

a (Any) – 将作为 `Argument` 发出到 `Graph` 中的值。

返回

`a` 转换为适当 `Argument` 的值。

返回类型

Argument

注意

此 API 的向后兼容性已得到保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[source]#

创建与根模块(`root` Module)的签名对应的 `placeholder` 节点。此方法内省 root 的签名并相应地发出这些节点,还支持 `*args` 和 `**kwargs`。

警告

此 API 是实验性的,并且向后兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]#

根据目标、参数、关键字参数和名称插入图节点。

可以覆盖此方法以执行额外的检查、验证或修改用于节点创建的值。例如,有人可能希望不允许记录就地操作。

注意

此 API 的向后兼容性已得到保证。

返回类型

Node

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]#

根据给定的参数创建 Node,然后将 Node 包装在 Proxy 对象中返回。

如果 kind = ‘placeholder’,那么我们正在创建一个表示函数参数的 Node。如果需要编码默认参数,我们使用 `args` 元组。对于 `placeholder` Nodes,`args` 否则为空。

注意

此 API 的向后兼容性已得到保证。

get_fresh_qualname(prefix)[source]#

为前缀获取一个新名称并返回它。此函数确保它不会与图上的现有属性冲突。

注意

此 API 的向后兼容性已得到保证。

返回类型

str

getattr(attr, attr_val, parameter_proxy_cache)[source]#

指定此 `Tracer` 在调用 `nn.Module` 实例上的 getattr 时行为的方法。

默认情况下,行为是返回属性的代理值。它还将代理值存储在 `parameter_proxy_cache` 中,以便将来的调用将重用该代理而不是创建新代理。

可以覆盖此方法来实现,例如,在查询参数时不返回代理。

参数
  • attr (str) – 查询的属性名称。

  • attr_val (Any) – 属性的值。

  • parameter_proxy_cache (Dict[str, Any]) – 属性名到代理的缓存。

返回

来自 getattr 调用的返回值。

警告

此 API 是实验性的,并且向后兼容。

is_leaf_module(m, module_qualified_name)[source]#

指定给定 `nn.Module` 是否为“叶子”模块的方法。

叶子模块是 IR 中出现的原子单元,由 `call_module` 调用引用。默认情况下,PyTorch 标准库命名空间(torch.nn)中的模块是叶子模块。所有其他模块都将被跟踪并记录其组成的 op,除非通过此参数另有指定。

参数
  • m (Module) – 查询的模块。

  • module_qualified_name (str) – 到此模块根的路径。例如,如果模块层次结构中的子模块 `foo` 包含子模块 `bar`,该子模块又包含子模块 `baz`,则该模块将在此处显示为限定名称 `foo.bar.baz`。

返回类型

布尔值

注意

此 API 的向后兼容性已得到保证。

iter(obj)[source]#
在迭代代理对象时调用,例如

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义跟踪器可以通过 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。

注意

此 API 的向后兼容性已得到保证。

返回类型

迭代器。

keys(obj)[source]#
在代理对象调用 keys() 方法时调用。

这就是当对代理对象执行 ** 时发生的情况。这应该返回一个迭代器,如果 ** 在您的自定义跟踪器中工作的话。

注意

此 API 的向后兼容性已得到保证。

返回类型

任何

path_of_module(mod)[源]#

辅助方法,用于查找 root 的 Module 层次结构中 mod 的限定名称。例如,如果 root 有一个名为 foo 的子模块,而 foo 又有一个名为 bar 的子模块,则将 bar 传递给此函数将返回字符串 “foo.bar”。

参数

mod (str) – 要为其检索限定名称的 Module

返回类型

str

注意

此 API 的向后兼容性已得到保证。

proxy(node)[源]#

注意

此 API 的向后兼容性已得到保证。

返回类型

代理

to_bool(obj)[源]#
当代理对象被转换为布尔值时调用,例如

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 向图节点附加更多信息,并可以选择返回值。

注意

此 API 的向后兼容性已得到保证。

返回类型

布尔值

trace(root, concrete_args=None)[源]#

跟踪 root 并返回相应的 FX Graph 表示。 root 可以是 nn.Module 实例或 Python 可调用对象。

请注意,调用此方法后,self.root 可能与此处传入的 root 不同。例如,当将一个自由函数传递给 trace() 时,我们将创建一个 nn.Module 实例作为根并添加嵌入的常量。

参数
  • root (Union[Module, Callable]) – 要通过跟踪的 Module 或函数。此参数向后兼容。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。此参数是实验性的,其向后兼容性有保证。

返回

一个 Graph,表示传入的 root 的语义。

返回类型

Graph

注意

此 API 的向后兼容性已得到保证。

class torch.fx.Proxy(node, tracer=None)[源]#

Proxy 对象是 Node 的包装器,它们在符号跟踪期间在程序中流动,并将它们触及的所有操作(torch 函数调用、方法调用、运算符)记录到不断增长的 FX Graph 中。

如果您正在进行图变换,您可以将自己的 Proxy 方法包装在原始 Node 上,以便可以使用重载运算符将其他内容添加到 Graph 中。

Proxy 对象不能迭代。换句话说,如果 Proxy 在循环中使用或作为 *args/**kwargs 函数参数使用,符号跟踪器将抛出错误。

通常有两种方法可以解决此问题:1. 将无法跟踪的逻辑提取到一个顶级函数中,并对其使用 fx.wrap。 2. 如果控制流是静态的(即循环次数取决于某个超参数),则可以将代码保留在原始位置并重构为类似

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关 Proxy 内部机制的更详细说明,请参阅 torch/fx/README.md 中的“Proxy”部分。

注意

此 API 的向后兼容性已得到保证。

class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源]#

Interpreter 按节点逐个执行 FX 图。这种模式可用于许多用途,包括编写代码转换和分析传递。

可以覆盖 Interpreter 类中的方法来定制执行行为。可覆盖方法的调用层次图

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们要将所有 torch.neg 实例与 torch.sigmoid 相互替换(包括它们的 Tensor 方法等效项)。我们可以像这样继承 Interpreter

class NegSigmSwapInterpreter(Interpreter):
    def call_function(
        self, target: Target, args: Tuple, kwargs: Dict
    ) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
  • module (torch.nn.Module) – 要执行的模块

  • garbage_collect_values (bool) – 是否在模块执行期间删除其最后一个使用的值。这可确保在执行过程中实现最佳内存使用。例如,可以通过查看 Interpreter.env 属性来禁用此项,以检查执行中的所有中间值。

  • graph (Optional[Graph]) – 如果已传递,解释器将执行此图而不是 module.graph,并使用提供的 module 参数来满足任何状态请求。

注意

此 API 的向后兼容性已得到保证。

boxed_run(args_list)[源]#

通过解释运行 module 并返回结果。这使用了“boxed”调用约定,您需要传递一个参数列表,该列表将被解释器清除。这确保输入张量能被及时释放。

注意

此 API 的向后兼容性已得到保证。

call_function(target, args, kwargs)[源]#

执行 call_function 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任何

Return

Any: 函数调用返回的值

注意

此 API 的向后兼容性已得到保证。

call_method(target, args, kwargs)[源]#

执行 call_method 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任何

Return

Any: 方法调用返回的值

注意

此 API 的向后兼容性已得到保证。

call_module(target, args, kwargs)[源]#

执行 call_module 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任何

Return

Any: 模块调用返回的值

注意

此 API 的向后兼容性已得到保证。

fetch_args_kwargs_from_env(n)[源]#

从当前执行环境中获取节点 nargskwargs 的具体值。

参数

n (Node) – 应从中获取 argskwargs 的节点。

返回

argskwargs 包含 n 的具体值。

返回类型

Tuple[Tuple, Dict]

注意

此 API 的向后兼容性已得到保证。

fetch_attr(target)[源]#

self.moduleModule 层次结构中获取一个属性。

参数

target (str) – 要获取的属性的完全限定名称

返回

属性的值。

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

get_attr(target, args, kwargs)[源]#

执行 get_attr 节点。将从 self.moduleModule 层次结构中检索属性值。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回

检索到的属性值

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

map_nodes_to_values(args, n)[源]#

递归地遍历 args,并在当前执行环境中查找每个 Node 的具体值。

参数
  • args (Argument) – 用于查找具体值的元数据结构

  • n (Node) – args 所属的节点。这仅用于错误报告。

返回类型

Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性已得到保证。

output(target, args, kwargs)[源]#

执行 output 节点。这实际上只是检索 output 节点引用的值并返回它。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回

输出节点引用的返回值

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

placeholder(target, args, kwargs)[源]#

执行 placeholder 节点。请注意,这是有状态的:Interpreter 维护一个关于传递给 run 的参数的内部迭代器,此方法返回该迭代器的 next()。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回

检索到的参数值。

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

run(*args, initial_env=None, enable_io_processing=True)[源]#

通过解释运行 module 并返回结果。

参数
  • *args – 要运行的模块的参数,按位置顺序排列

  • initial_env (Optional[Dict[Node, Any]]) – 执行的可选初始环境。这是一个将 Node 映射到任意值的字典。例如,这可用于预先填充某些 Nodes 的结果,以便只在解释器中进行部分求值。

  • enable_io_processing (bool) – 如果为 true,我们将使用 graph 的 process_inputs 和 process_outputs 函数处理输入和输出,然后再使用它们。

返回

执行模块返回的值

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

run_node(n)[源]#

运行特定节点 n 并返回结果。根据 node.op 调用 placeholder、get_attr、call_function、call_method、call_module 或 output。

参数

n (Node) – 要执行的节点

返回

执行 n 的结果

返回类型

任何

注意

此 API 的向后兼容性已得到保证。

class torch.fx.Transformer(module)[源]#

Transformer 是一种特殊的解释器,它生成新的 Module。它公开了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 不需要参数即可运行,而 Interpreter 需要。 Transformer 完全以符号方式工作。

示例

假设我们要将所有 torch.neg 实例与 torch.sigmoid 相互替换(包括它们的 Tensor 方法等效项)。我们可以像这样继承 Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(
        self,
        target: "Target",
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(
        self,
        target: "Target",
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)

transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数

module (GraphModule) – 要转换的 Module

注意

此 API 的向后兼容性已得到保证。

call_function(target, args, kwargs)[源]#

注意

此 API 的向后兼容性已得到保证。

返回类型

任何

call_module(target, args, kwargs)[源]#

注意

此 API 的向后兼容性已得到保证。

返回类型

任何

get_attr(target, args, kwargs)[源]#

执行 get_attr 节点。在 Transformer 中,这被重写为将一个新的 get_attr 节点插入到输出图中。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

代理

注意

此 API 的向后兼容性已得到保证。

placeholder(target, args, kwargs)[源]#

执行 placeholder 节点。在 Transformer 中,这被重写为将一个新的 placeholder 插入到输出图中。

参数
  • target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

代理

注意

此 API 的向后兼容性已得到保证。

transform()[源]#

转换 self.module 并返回转换后的 GraphModule

注意

此 API 的向后兼容性已得到保证。

返回类型

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[源]#

匹配 GraphModule (gm) 的 Graph 中所有可能的非重叠的算子集及其数据依赖性 (pattern),然后将每个匹配到的子图替换为另一个子图 (replacement)。

参数
返回

一个 Match 对象列表,表示原始图中 pattern 匹配到的位置。如果没有任何匹配,则列表为空。 Match 定义为

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回类型

List[Match]

示例

import torch
from torch.fx import symbolic_trace, subgraph_rewriter


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2])


def replacement(w1, w2):
    return torch.stack([w1, w2])


traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先在 traced_moduleforward 方法中匹配 pattern。模式匹配基于使用-定义关系,而不是节点名称。例如,如果您在 pattern 中有 p = torch.cat([a, b]),那么您可以在原始 forward 函数中匹配 m = torch.cat([a, b]),即使变量名不同(p vs m)。

pattern 可调用对象中的 return 语句仅基于其值进行匹配;它可能匹配也可能不匹配到较大图中的 return 语句。换句话说,模式不必延伸到较大图的末尾。

当模式匹配成功时,它将从较大函数中移除并被 replacement 替换。如果在较大函数中有多个 pattern 的匹配项,则每个非重叠的匹配项都将被替换。在匹配重叠的情况下,将替换重叠匹配项集中找到的第一个匹配项。(这里的“第一个”定义为拓扑排序的使用-定义关系中的第一个节点。在大多数情况下,第一个节点是直接出现在 self 之后的参数,而最后一个节点是函数返回的任何内容。)

需要注意的一点是,pattern 可调用对象的参数必须在可调用对象本身中使用,并且 replacement 可调用对象的参数必须与模式匹配。第一个规则解释了为什么在上面的代码块中,forward 函数具有参数 x, w1, w2,而 pattern 函数只有参数 w1, w2pattern 不使用 x,因此不应将 x 指定为参数。作为第二个规则的示例,请考虑替换

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

替换

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同数量的参数(xy),即使参数 yreplacement 中未使用。

调用 subgraph_rewriter.replace_pattern 后,生成的 Python 代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

此 API 的向后兼容性已得到保证。