torch.fx#
创建于:2020年12月15日 | 最后更新于:2025年12月05日
概述#
FX 是一个供开发者用于转换 nn.Module 实例的工具包。FX 由三个主要组件组成:符号追踪器 (symbolic tracer),中间表示 (intermediate representation),以及 Python 代码生成 (Python code generation)。这些组件协同工作的演示
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器 执行 Python 代码的“符号执行”。它将假的、称为代理 (Proxies) 的值输入代码。对这些代理的操作会被记录下来。更多关于符号追踪的信息可以在 symbolic_trace() 和 Tracer 文档中找到。
中间表示 是符号追踪过程中记录的操作的容器。它包含一系列代表函数输入、调用点(函数、方法或 torch.nn.Module 实例)以及返回值的节点。更多关于 IR 的信息可以在 Graph 的文档中找到。IR 是应用转换的格式。
Python 代码生成 是 FX 成为一个 Python 到 Python(或模块到模块)转换工具包的原因。对于每个 Graph IR,我们可以生成与之语义匹配的有效 Python 代码。此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,包含一个 Graph 以及一个从 Graph 生成的 forward 方法。
总而言之,这个组件流水线(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换流水线。此外,这些组件也可以单独使用。例如,符号追踪可以独立用于捕获代码的一种形式以供分析(而非转换)。代码生成可用于以编程方式生成模型,例如从配置文件生成。FX 有很多用途!
可以在 examples 仓库中找到几个示例转换。
编写转换#
什么是 FX 转换?本质上,它是一个如下所示的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换将接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,然后返回一个新的 torch.nn.Module。您应该将 FX 转换返回的 torch.nn.Module 视为与普通 torch.nn.Module 相同——您可以将其传递给另一个 FX 转换,或者您可以运行它。确保您的 FX 转换的输入和输出是 torch.nn.Module 将允许组合性。
注意
也可以修改现有的 GraphModule 而不是创建一个新的,如下所示
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,您必须调用 GraphModule.recompile() 以使 GraphModule 上生成的 forward() 方法与修改后的 Graph 同步。
鉴于您已传入一个已被追踪成 Graph 的 torch.nn.Module,现在有两种主要方法可以用来构建新的 Graph。
图简介#
关于图语义的完整论述可以在 Graph 文档中找到,但我们在此将介绍基础知识。一个 Graph 是一个表示 GraphModule 中某个方法的的数据结构。它所需的信息是
方法的输入是什么?
方法内运行的操作是什么?
方法的输出(即返回值)是什么?
所有这三个概念都用 Node 实例表示。让我们通过一个简短的示例来看看我们是什么意思
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
这里我们为了演示目的定义了一个模块 MyModule,实例化它,符号追踪它,然后调用 Graph.print_tabular() 方法来打印一个表格,显示这个 Graph 的节点
操作码 |
名称 |
目标 |
参数 |
关键字参数 |
|---|---|---|---|---|
占位符 |
x |
x |
() |
{} |
get_attr |
linear_weight |
linear.weight |
() |
{} |
call_function |
add_1 |
(x, linear_weight) |
{} |
|
call_module |
linear_1 |
linear |
(add_1,) |
{} |
call_method |
relu_1 |
relu |
(linear_1,) |
{} |
call_function |
sum_1 |
<built-in method sum …> |
(relu_1,) |
{‘dim’: -1} |
call_function |
topk_1 |
<built-in method topk …> |
(sum_1, 3) |
{} |
output |
output |
output |
(topk_1,) |
{} |
我们可以使用这些信息来回答上面提出的问题。
方法的输入是什么?在 FX 中,方法输入通过特殊的
placeholder节点指定。在这种情况下,我们有一个名为 x 的placeholder节点,其target为x,这意味着我们有一个单独的(非 self)参数 x。方法内的操作是什么?
get_attr、call_function、call_module和call_method节点代表方法中的操作。所有这些操作语义的完整处理可以在Node文档中找到。方法的返回值是什么?
Graph中的返回值由一个特殊的output节点指定。
既然我们现在知道了 FX 中代码表示的基础知识,我们就可以探索如何编辑 Graph 了。
图操作#
直接图操作#
构建新 Graph 的一种方法是直接操作旧图。为了便于此,我们可以简单地获取从符号追踪获得的 Graph 并对其进行修改。例如,假设我们希望用 torch.mul() 调用替换 torch.add() 调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph 重写,例如删除或追加节点。为了辅助这些转换,FX 在 Graph 文档中提供了用于转换图的实用函数。下面有一个使用这些 API 追加 torch.relu() 调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,您还可以利用 子图重写器 (subgraph rewriter)。
使用 replace_pattern() 进行子图重写#
FX 还提供了在直接图操作之上的另一个自动化级别。 replace_pattern() API 本质上是用于编辑 Graph 的“查找/替换”工具。它允许您指定一个 pattern 和一个 replacement 函数,它将追踪这些函数,在 pattern 图中找到操作组的实例,然后用 replacement 图的副本替换这些实例。这有助于极大地自动化繁琐的图操作代码,因为转换变得更复杂时,这些代码可能会变得难以管理。
图操作示例#
代理/重追踪#
操作 Graph 的另一种方法是重用符号追踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是执行必要的图重写,在 F.relu 之后插入比较和乘法,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象自动将操作记录到 Graph 中来自动化此过程。
要使用此方法,我们将要插入的操作写成常规的 PyTorch 代码,并使用 Proxy 对象作为参数来调用该代码。这些 Proxy 对象将捕获在其上执行的操作,并将它们追加到 Graph。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式的图操作外,使用 Proxy 还可以让您将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy 时,我们还传递了一个指向底层变量 graph 的追踪器。这样做是为了防止如果图中的操作是 n 元的(例如,add 是二元运算符),则调用 Proxy 会创建多个图追踪器实例,这可能导致意外的运行时错误。我们推荐此方法使用 Proxy,尤其是在无法安全地假定底层运算符是单目运算符时。
解释器模式#
FX 中一个有用的代码组织模式是遍历 Graph 中的所有 Node 并执行它们。这可用于多种用途,包括运行时分析流经图的值,或通过使用 Proxy 进行重追踪来转换代码。例如,假设我们想运行一个 GraphModule 并在运行时看到 torch.Tensor 的形状和 dtype 属性时记录它们。这可能看起来像
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如您所见,FX 的完整解释器并不复杂,但可能非常有用。为了方便使用此模式,我们提供了 Interpreter 类,该类封装了上述逻辑,可以通过方法覆盖来覆盖解释器的某些执行方面。
除了执行操作之外,我们还可以通过将 Proxy 值输入解释器来生成新的 Graph。同样,我们提供了 Transformer 类来封装此模式。Transformer 的行为类似于 Interpreter,但不是调用 run 方法来获取模块的具体输出值,而是调用 Transformer.transform() 方法来返回一个新的 GraphModule,该 GraphModule 受您安装为覆盖方法的任何转换规则的影响。
解释器模式示例#
调试#
简介#
在编写转换的过程中,我们的代码经常不会完全正确。在这种情况下,我们可能需要进行一些调试。关键在于逆向工作:首先,检查生成模块的调用结果,以证明或反驳正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用的调试器。
检查模块的正确性#
由于大多数深度学习模块的输出由浮点 torch.Tensor 实例组成,因此检查两个 torch.nn.Module 结果之间的等价性不像进行简单的相等性检查那样直接。为了说明这一点,让我们看一个例子
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用 == 相等运算符来检查两个深度学习模型的数值相等性。然而,这并不明确,
不仅因为该运算符返回一个张量而不是布尔值,而且因为浮点值的比较应使用误差范围(或 epsilon)来考虑浮点运算的非交换性(有关更多详细信息,请参见此处)。我们可以改用 torch.allclose(),它将根据相对和绝对容差阈值给出近似比较。
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中用于检查转换后的模块是否按预期行为相对于参考实现的第一种工具。
调试生成的代码#
因为 FX 在 GraphModule 上生成 forward() 函数,所以使用传统的调试技术,如 print 语句或 pdb,并没有那么直接。幸运的是,我们有几种技术可用于调试生成的代码。
使用 pdb#
调用 pdb 以进入正在运行的程序。虽然表示 Graph 的代码不在任何源文件中,但我们仍然可以在调用前向传播时手动使用 pdb 进入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码#
如果您想多次运行相同的代码,那么使用 pdb 步入正确的代码可能会有些繁琐。在这种情况下,一种方法是简单地将生成的 forward 传递复制并粘贴到您的代码中,然后从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule 中的 to_folder 函数#
GraphModule.to_folder() 是 GraphModule 中的一个方法,允许您将生成的 FX 代码转储到一个文件夹。虽然将前向传递复制到代码中通常就足够了,如 打印生成的代码 中所述,但使用 to_folder 检查模块和参数可能会更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上面的示例后,我们可以查看 foo/module.py 中的代码,并根据需要进行修改(例如,添加 print 语句或使用 pdb)来调试生成的代码。
调试转换#
现在我们已经确定某个转换正在创建不正确的代码,是时候调试转换本身了。首先,我们将查看文档中的 符号追踪的局限性 部分。一旦我们验证了追踪按预期工作,目标就是弄清楚在我们的 GraphModule 转换过程中出了什么问题。也许在 编写转换 中有简单的答案,但如果没有,有几种方法可以检查我们追踪的模块。
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的实用函数,我们可以比较我们追踪的模块在应用转换之前和之后的状态。有时,简单的视觉比较足以找出错误。如果仍然不清楚哪里出了问题,像 pdb 这样的调试器可能是一个不错的下一步。
沿用上面的示例,考虑以下代码
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的示例,假设 print(traced) 的调用显示我们的转换存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 上设置断点,然后按 s “进入”对 transform_graph(traced) 的调用来查看转换过程中发生了什么。
我们也可以尝试编辑 print_tabular 方法来打印 Graph 中 Nodes 的不同属性。(例如,我们可能想查看 Node 的 input_nodes 和 users。)
可用的调试器#
最常见的 Python 调试器是 pdb。您可以通过在命令行中键入 python -m pdb FILENAME.py 来以“调试模式”启动程序,其中 FILENAME 是您要调试的文件名。之后,您可以使用 pdb 的 调试器命令 来逐步执行正在运行的程序。在启动 pdb 时设置断点(b LINE-NUMBER),然后调用 c 来运行程序直到该点,这是很常见的。这可以避免您必须逐行执行(使用 s 或 n)来找到您想检查的代码部分。或者,您可以在要中断的行之前编写 import pdb; pdb.set_trace()。如果您添加了 pdb.set_trace(),当您运行程序时,它将自动进入调试模式。(换句话说,您只需在命令行中键入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦您在调试模式下运行文件,您就可以使用某些命令逐步执行代码并检查程序的内部状态。在线上有很多关于 pdb 的精彩教程,包括 RealPython 的 “Python Debugging With Pdb”。
像 PyCharm 或 VSCode 这样的 IDE 通常内置调试器。在您的 IDE 中,您可以选择 a) 通过在 IDE 中打开一个终端窗口(例如 VSCode 中的 View → Terminal)来使用 pdb,或者 b) 使用内置调试器(通常是 pdb 的图形化包装器)。
符号追踪的局限性#
FX 使用一个符号追踪(又名 符号执行)系统,以可转换/可分析的形式捕获程序的语义。该系统是追踪的,因为它会执行程序(实际上是 torch.nn.Module 或函数)来记录操作。它是符号化的,因为在这个执行过程中流经程序的数据不是真实数据,而是符号(在 FX 术语中是 Proxy)。
尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。
动态控制流#
符号追踪的主要局限性是它目前不支持动态控制流。也就是说,它不支持依赖于程序输入值的循环或 if 语句。
例如,我们来看下面的程序
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if 语句的条件依赖于 x.sum() 的值,而 x.sum() 又依赖于函数输入 x 的值。由于 x 会发生变化(即,如果您将新的输入张量传递给被追踪的函数),这就是动态控制流。回溯会沿着代码向上追踪,向您展示这种情况发生的位置。
静态控制流#
另一方面,所谓的静态控制流是受支持的。静态控制流是指其值在调用之间不会改变的循环或 if 语句。通常,在 PyTorch 程序中,这种控制流出现在基于超参数决定模型架构的代码中。一个具体的例子是
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if self.do_activation 这个 if 语句不依赖于任何函数输入,因此是静态的。 do_activation 可以被视为一个超参数,不同 MyModule 实例,即使该参数的值不同,它们的追踪代码也是不同的。这是一个有效的模式,并且受符号追踪支持。
许多动态控制流的实例实际上是静态控制流。这些实例可以通过消除对输入值的依赖来支持符号追踪,例如,将值移至 Module 属性,或在符号追踪期间将具体值绑定到参数。
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
对于真正动态的控制流,包含此代码的程序部分可以被追踪为对 Method(参见 使用 Tracer 类自定义追踪)或函数(参见 wrap())的调用,而不是对其进行追踪。
非 torch 函数#
FX 使用 __torch_function__ 作为拦截调用的机制(有关更多信息,请参阅 技术概述)。一些函数,例如内置 Python 函数或 math 模块中的函数,不被 __torch_function__ 覆盖,但我们仍然希望在符号追踪中捕获它们。例如
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误告诉我们内置函数 len 不受支持。我们可以使用 wrap() API 使像这样的函数被记录为对它们的直接调用,以包含在追踪中。
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer 类自定义追踪#
Tracer 类是 symbolic_trace 实现的基础。可以通过继承 Tracer 类来定制追踪行为,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶子模块#
叶子模块是在符号追踪中显示为调用的模块,而不是被追踪的模块。默认的叶子模块集合是标准的 torch.nn 模块实例集合。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过重写 Tracer.is_leaf_module() 来定制叶子模块的集合。
杂项#
张量构造函数(例如
torch.zeros、torch.ones、torch.rand、torch.randn、torch.sparse_coo_tensor)目前不可追踪。确定性构造函数(
zeros、ones)可以使用,并且它们产生的值将被嵌入到追踪中作为常量。只有当这些构造函数的参数引用动态输入大小时才会出现问题。在这种情况下,ones_like或zeros_like可能是可行的替代方案。非确定性构造函数(
rand、randn)将在追踪中嵌入一个随机值。这可能不是预期的行为。一个解决方法是将torch.randn包装在torch.fx.wrap函数中,然后调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能在将来的版本中修复。
类型注解
Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor)是受支持的,并且将被符号追踪保留。Python 2 风格的注释类型注解
# type: (torch.Tensor, int) -> torch.Tensor目前不受支持。函数内部局部名称上的注解目前不受支持。
training标志和子模块的陷阱当使用
torch.nn.functional.dropout等函数式 API 时,通常会将 training 参数作为self.training传入。在 FX 追踪期间,这很可能会被硬编码为一个常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
然而,当使用标准的
nn.Dropout()子模块时,training 标志被封装起来,并且由于nn.Module对象模型的保留,可以被改变。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由于这种差异,请考虑将与
training标志动态交互的模块标记为叶子模块。
API参考#
- torch.fx.symbolic_trace(root, concrete_args=None)[source]#
符号追踪 API
给定一个
nn.Module或函数实例root,此函数将返回一个GraphModule,该GraphModule是通过记录在追踪root过程中看到的运算来构建的。concrete_args允许您部分特化函数,无论是为了消除控制流还是数据结构。例如
def f(a, b): if b == True: return a else: return a * 2
由于控制流的存在,FX 通常无法追踪它。但是,我们可以使用concrete_args 来特化b 的值,从而进行追踪。
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
请注意,虽然您仍然可以传入 b 的不同值,但它们将被忽略。
我们还可以使用 concrete_args 来消除函数中的数据结构处理。这将使用 pytrees 来展平您的输入。为了避免过度特化,请为不应特化的值传入
fx.PH。例如:def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace( f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} ) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 参数:
root (Union[torch.nn.Module, Callable]) – 要被追踪并转换为 Graph 表示的模块或函数。
concrete_args (Optional[Dict[str, any]]) – 要部分特化的输入。
- 返回:
一个由
root记录的操作构建而成的模块。- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- torch.fx.wrap(fn_or_name)[source]#
此函数可以在模块级别作用域调用,以将 fn_or_name 注册为“叶子函数”。“叶子函数”将在 FX 追踪中被保留为一个 CallFunction 节点,而不是被追踪。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以等效地用作装饰器:
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
被包装的函数可以被视为“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在 FX 追踪中被保留为调用的函数,而不是被追踪。
- 参数:
fn_or_name (Union[str, Callable]) – 要在图表中插入的函数或全局函数的名称。
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.GraphModule(*args, **kwargs)[source]#
GraphModule 是一个从 fx.Graph 生成的 nn.Module。Graphmodule 具有
graph属性,以及从该graph生成的code和forward属性。警告
当
graph被重新赋值时,code和forward将被自动重新生成。但是,如果您在未重新赋值graph属性本身的情况下编辑graph的内容,则必须调用recompile()来更新生成的代码。注意
此 API 的向后兼容性已得到保证。
- __init__(root, graph, class_name='GraphModule')[source]#
构造一个 GraphModule。
- 参数:
root (Union[torch.nn.Module, Dict[str, Any]) –
root可以是 nn.Module 实例,也可以是映射字符串到任何属性类型的 Dict。如果root是一个 Module,则 Graph 的 Nodes 的target字段中对 Module 对象(通过限定名称)的任何引用都将从root的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果root是一个 dict,则在 Node 的target中找到的限定名称将直接在 dict 的键中查找。由 Dict 映射的对象将被复制到 GraphModule 的模块层次结构中的适当位置。graph (Graph) –
graph包含此 GraphModule 应用于代码生成的节点。class_name (str) –
name指示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息都将报告为源自GraphModule。将此设置为root的原始名称或在您的转换上下文中具有意义的名称可能很有用。
注意
此 API 的向后兼容性已得到保证。
- add_submodule(target, m)[source]#
将给定的子模块添加到
self。如果
target的子路径上尚无模块,则此操作将在此处安装空模块。- 参数:
- 返回:
- 子模块是否能够被插入。对于
此方法返回 True,
target所表示的链中的每个对象都必须:a) 尚不存在,或 b) 引用一个nn.Module(而不是参数或其他属性)。
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- delete_all_unused_submodules()[source]#
从
self中删除所有未使用的子模块。当以下任一情况为真时,模块被认为是“已使用”:1. 它拥有已使用的子模块。2. 它的 forward 方法直接通过
call_module节点调用。3. 它有一个非 Module 属性,该属性是从get_attr节点使用的。可以调用此方法来清理
nn.Module,而无需手动调用delete_submodule来删除每个未使用的子模块。注意
此 API 的向后兼容性已得到保证。
- delete_submodule(target)[source]#
从
self中删除给定的子模块。如果
target不是有效目标,则模块不会被删除。- 参数:
target (str) – 新子模块的完全限定字符串名称(请参阅
nn.Module.get_submodule中的示例,了解如何指定完全限定字符串。)- 返回:
- 目标字符串是否引用了一个
我们要删除的子模块。返回
False表示target不是对子模块的有效引用。
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False, *, fast_sympy_print=False, expanded_def=False)[source]#
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码。
警告
此 API 是实验性的,并且不向后兼容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
Graph是 FX 中间表示中使用的主要数据结构。它由一系列Node组成,每个Node代表一个调用点(或其他语法构造)。Node的列表共同构成了一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
将产生以下 Graph
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1有关
Graph中操作的语义,请参阅Node。注意
此 API 的向后兼容性已得到保证。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
构造一个空的 Graph。
注意
此 API 的向后兼容性已得到保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None, name=None)[source]#
将
call_functionNode插入到Graph中。call_function节点表示对由the_function指定的 Python 可调用对象的调用。- 参数:
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或
builtins或operator命名空间中的成员。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用函数的 positional 参数。
kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
name (Optional[str]) – 节点的名称。如果未指定,则设置为 None。
- 返回:
新创建并插入的
call_function节点。- 返回类型:
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性已得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source]#
将
call_methodNode插入到Graph中。call_method节点表示对args的第 0 个元素上的给定方法进行调用。- 参数:
method_name (str) – 要应用于 self 参数的方法的名称。例如,如果 args[0] 是表示
Tensor的Node,则要在该Tensor上调用relu(),则将relu传递给method_name。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的 positional 参数。请注意,这应该包含一个
self参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回:
新创建并插入的
call_method节点。- 返回类型:
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性已得到保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source]#
将
call_moduleNode插入到Graph中。call_module节点表示对Module层次结构中的Module的forward()函数的调用。- 参数:
module_name (str) – 要调用的
Module层次结构中Module的限定名称。例如,如果跟踪的Module有一个名为foo的子模块,该子模块有一个名为bar的子模块,则应将限定名称foo.bar作为module_name传递以调用该模块。args (Optional[Tuple[Argument, ...]]) – 将传递给被调用方法的 positional arguments。请注意,这 *不* 应该包含
self参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回:
新创建并插入的
call_module节点。- 返回类型:
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性已得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]#
创建一个
Node并将其添加到当前插入点的Graph中。请注意,当前插入点可以通过Graph.inserting_before()和Graph.inserting_after()设置。- 参数:
op (str) – 此 Node 的操作码。可以是 'call_function'、'call_method'、'get_attr'、'call_module'、'placeholder' 或 'output'。这些操作码的语义在
Graph文档字符串中进行了描述。args (Optional[Tuple[Argument, ...]]) – 此节点参数的元组。
kwargs (Optional[Dict[str, Argument]]) – 此 Node 的 kwargs。
name (Optional[str]) –
Node的可选字符串名称。这将影响生成的 Python 代码中分配给该值的名称。type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回:
新创建并插入的节点。
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- eliminate_dead_code(is_impure_node=None)[source]#
根据每个节点的用法数量以及节点是否具有任何副作用,从图中删除所有死代码。调用前必须对图进行拓扑排序。
- 参数:
- 返回:
由于该过程,图是否已更改。
- 返回类型:
示例
在消除死代码之前,下面的 a = x + 1 中的 a 没有用法,因此可以从图中消除而不会产生任何影响。
def forward(self, x): a = x + 1 return x + self.attr_1
消除死代码后,a = x + 1 已被删除,而 forward 的其余部分保留。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法可以避免删除具有副作用的节点(请参阅 Node.is_impure),但总的来说覆盖率非常差,因此您应该假定调用此方法不安全,除非您知道您的 FX 图完全由函数式操作组成,或者您提供了自己的自定义函数来检测具有副作用的节点。
注意
此 API 的向后兼容性已得到保证。
- erase_node(to_erase)[source]#
从
Graph中擦除一个Node。如果Graph中仍有该节点的用法,则会引发异常。- 参数:
to_erase (Node) – 要从
Graph中擦除的Node。
注意
此 API 的向后兼容性已得到保证。
- find_nodes(*, op, target=None, sort=True)[source]#
允许快速查询节点
- 参数:
- 返回:
具有请求的 op 和 target 的节点的可迭代对象。
警告
此 API 是实验性的,并且不向后兼容。
- get_attr(qualified_name, type_expr=None)[source]#
将
get_attr节点插入到 Graph 中。get_attrNode表示从Module层次结构中获取属性。- 参数:
qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的 Module 有一个名为
foo的子模块,该子模块有一个名为bar的子模块,该子模块有一个名为baz的属性,则应将限定名称foo.bar.baz作为qualified_name传递。type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回:
新创建并插入的
get_attr节点。- 返回类型:
注意
此方法与
Graph.create_node具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- graph_copy(g, val_map, return_output_node=False)[source]#
将给定图中的所有节点复制到
self中。- 参数:
- 返回:
self中现在等价于g中输出值的那个值,如果g有一个output节点的话。否则为None。- 返回类型:
tuple[Argument, …] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None
注意
此 API 的向后兼容性已得到保证。
- inserting_after(n=None)[source]#
- 设置 create_node 和辅助方法将插入图中的点。
当在 'with' 语句中使用时,它将临时设置插入点,然后在 with 语句退出时恢复插入点。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数 (Args)
- n (Optional[Node]): 要在其之前插入的节点。如果为 None,则将在图的开头之后插入。
图的开头。
- 返回
一个资源管理器,它将在
__exit__时恢复插入点。
注意
此 API 的向后兼容性已得到保证。
- inserting_before(n=None)[source]#
- 设置 create_node 和辅助方法将插入图中的点。
当在 'with' 语句中使用时,它将临时设置插入点,然后在 with 语句退出时恢复插入点。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数 (Args)
- n (Optional[Node]): 要在其之前插入的节点。如果为 None,则将在图的开头之前插入。
图的开头。
- 返回
一个资源管理器,它将在
__exit__时恢复插入点。
注意
此 API 的向后兼容性已得到保证。
- lint()[source]#
运行此 Graph 的各种检查以确保其格式正确。特别是:- 检查节点是否具有正确的归属(属于此图)- 检查节点是否按拓扑顺序出现- 如果此 Graph 具有所有者 GraphModule,则检查 Target 是否存在于该 GraphModule 中
注意
此 API 的向后兼容性已得到保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source]#
将一个节点从一个图复制到另一个图。
arg_transform需要将参数从 node 的图转换到 self 的图。示例# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 参数:
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- property nodes: _node_list#
获取构成此 Graph 的 Nodes 列表。
注意,此
Node列表表示一个双向链表。在迭代过程中进行修改(例如,删除一个 Node,添加一个 Node)是安全的。- 返回:
Nodes 的双向链表。请注意,可以对此列表调用
reversed来切换迭代顺序。
- on_generate_code(make_transformer)[source]#
在生成 python 代码时注册一个转换器函数
- 参数 (Args)
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。
此函数还接收当前注册的代码转换器(如果未注册则为 None)作为输入,以防不希望覆盖它。这对于链式处理代码转换器很有用。
- 返回
一个上下文管理器,当在 with 语句中使用时,可以自动恢复以前注册的代码转换器。
示例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函数还可以用作上下文管理器,其好处是可以自动恢复以前注册的代码转换器。
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 是实验性的,并且不向后兼容。
- output(result, type_expr=None)[source]#
将
outputNode插入到Graph中。output节点表示 Python 代码中的return语句。result是应返回的值。- 参数:
result (Argument) – 要返回的值。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
注意
此方法与
Graph.create_node具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- placeholder(name, type_expr=None, default_value)[source]#
将
placeholder节点插入到 Graph 中。placeholder表示函数输入。- 参数:
name (str) – 输入值的名称。这对应于此
Graph所表示函数的 positional argument 的名称。type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。在某些情况下,这对于正确的代码生成是必需的(例如,当函数随后在 TorchScript 编译中使用时)。
default_value (Any) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定参数 *不* 具有默认值。
- 返回类型:
注意
此方法与
Graph.create_node具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False, expanded_def=False, record_func=False)[source]#
将此
Graph转换为有效的 Python 代码。- 参数:
root_module (str) – 用于查找限定名称 target 的根模块的名称。通常是 'self'。
- 返回:
src:表示对象的 Python 源代码 globals:src 中的全局名称字典 -> 它们引用的对象。
- 返回类型:
一个 PythonCode 对象,包含两个字段
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]#
Node是表示Graph中单个操作的数据结构。在大多数情况下,Nodes 表示对各种实体的调用点,例如运算符、方法和 Modules(一些例外包括指定函数输入和输出的节点)。每个Node都有一个由其op属性指定的函数。每个op值的Node语义如下:placeholder表示函数输入。name属性指定该值将具有的名称。target类似地是参数的名称。args持有以下之一:1) 无,或 2) 单个参数,表示函数输入的默认参数。kwargs是不关心的。Placeholder 对应于图打印输出中的函数参数(例如x)。get_attr从模块层次结构中检索参数。name类似地是提取结果被赋值的名称。target是参数在模块层次结构中位置的完全限定名称。args和kwargs是不关心的。call_function应用一个自由函数到一些值。name类似地是分配给该值。target是要应用的函数。args和kwargs表示函数的参数,遵循 Python 调用约定。call_module将模块层次结构中的模块的forward()方法应用于给定的参数。name与前述相同。target是要调用的模块在模块层次结构中的完全限定名称。args和kwargs表示调用模块的参数,*不包括 self 参数*。call_method调用一个值上的方法。name类似。target是要应用于self参数的方法的字符串名称。args和kwargs表示调用模块的参数,*包括 self 参数*。output在其args[0]属性中包含被跟踪函数的输出。这对应于图打印输出中的“return”语句。
注意
此 API 的向后兼容性已得到保证。
- property all_input_nodes: list[Node]#
返回作为此 Node 输入的所有 Nodes。这等效于迭代
args和kwargs并仅收集是 Nodes 的值。- 返回:
出现在此
Node的args和kwargs中的Nodes列表,按此顺序。
- append(x)[source]#
在图的节点列表中在此节点之后插入
x。等效于self.next.prepend(x)。- 参数:
x (Node) – 要放在此节点之后的节点。必须是同一图的成员。
注意
此 API 的向后兼容性已得到保证。
- property args: tuple[tuple[Argument, ...] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...]#
该
Node的参数元组。参数的解释取决于节点的 opcode。有关更多信息,请参阅Node文档字符串。允许为此属性赋值。所有使用和用户计数将在赋值时自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None, *, include_tensor_metadata=False)[source]#
返回
self的描述性字符串表示。此方法可不带参数地用作调试工具。
此函数也用于
Graph的__str__方法的内部。placeholder_names和maybe_return_typename中的字符串共同构成了此 Graph 外部 GraphModule 中 autogeneratedforward函数的签名。placeholder_names和maybe_return_typename不得用于其他目的。- 参数:
- 返回:
- 如果 1) 我们将
format_node用作内部帮助程序 在
Graph的__str__方法中,并且 2)self是一个占位符 Node,则返回None。否则,返回当前 Node 的描述性字符串表示。
- 如果 1) 我们将
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- insert_arg(idx, arg)[source]#
在指定索引处向参数列表插入一个位置参数。
- 参数:
idx (int) – 要在
self.args中插入元素的索引。arg (Argument) – 要插入到
args中的新参数值。
注意
此 API 的向后兼容性已得到保证。
- is_impure(impure_random=True)[source]#
返回此操作是否为不纯操作,即,如果其操作是占位符或输出,或者是一个不纯的 call_function 或 call_module。
警告
此 API 是实验性的,并且不向后兼容。
- property kwargs: dict[str, tuple[Argument, ...] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None]#
该
Node的关键字参数字典。参数的解释取决于节点的 opcode。有关更多信息,请参阅Node文档字符串。允许为此属性赋值。所有使用和用户计数将在赋值时自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]#
返回 Python 目标的标准化参数。这意味着 args/kwargs 将与模块/函数的签名匹配,并且如果 normalize_to_only_use_kwargs 为 true,则仅返回关键字参数。还填充默认值。不支持仅位置参数或可变参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 来区分重载。
- 参数:
root (torch.nn.Module) – 用于解析模块目标的模块。
arg_types (Optional[Tuple[Any]]) – 参数的类型元组。
kwarg_types (Optional[Dict[str, Any]]) – 关键字参数的类型字典。
normalize_to_only_use_kwargs (bool) – 是否标准化为仅使用关键字参数。
- 返回:
返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None。
- 返回类型:
ArgsKwargsPair | None
警告
此 API 是实验性的,并且不向后兼容。
- prepend(x)[source]#
在图中节点列表中在此节点之前插入 x。示例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数:
x (Node) – 要放在此节点之前的节点。必须是同一图的成员。
注意
此 API 的向后兼容性已得到保证。
- replace_all_uses_with(replace_with, delete_user_cb=None, *, propagate_meta=False)[source]#
将
self在 Graph 中的所有使用替换为 Nodereplace_with。- 参数:
- 返回:
在此更改上执行操作的 Node 列表。
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- replace_input_with(old_input, new_input)[source]#
遍历
self的输入节点,并将所有old_input的实例替换为new_input。注意
此 API 的向后兼容性已得到保证。
- property stack_trace: str | None#
返回跟踪期间记录的 Python 堆栈跟踪(如果有)。使用 fx.Tracer 进行跟踪时,此属性通常由 Tracer.create_proxy 填充。要为调试目的在跟踪期间记录堆栈跟踪,请在 Tracer 实例上设置 record_stack_traces = True。使用 dynamo 进行跟踪时,此属性将由 OutputGraph.create_proxy 默认填充。
stack_trace 的字符串末尾将是最内层帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]#
Tracer是实现torch.fx.symbolic_trace的符号跟踪功能的类。调用symbolic_trace(m)等同于Tracer().trace(m)。Tracer 可以被子类化以覆盖跟踪过程的各种行为。可覆盖的不同行为在此类的文档字符串中进行了描述。
注意
此 API 的向后兼容性已得到保证。
- call_module(m, forward, args, kwargs)[source]#
指定此
Tracer在遇到对nn.Module实例的调用时行为的方法。默认情况下,行为是检查调用模块是否是叶子模块(通过
is_leaf_module)。如果是,则在Graph中发出指向m的call_module节点。否则,正常调用Module,跟踪其forward函数中的操作。此方法可以被覆盖,例如,创建嵌套的跟踪 GraphModules,或在跨越
Module边界进行跟踪时所需的任何其他行为。- 参数:
m (Module) – 要发出调用的模块。
forward (Callable) – 要调用的
Module的 forward() 方法。args (Tuple) – 模块调用点的参数。
kwargs (Dict) – 模块调用点的关键字参数。
- 返回:
模块调用的返回值。如果发出了
call_module节点,则此值为Proxy值。否则,它将是模块调用返回的任何值。- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- create_arg(a)[source]#
用于指定跟踪在准备用作
Graph中节点参数的值时的行为的方法。默认情况下,行为包括:
迭代集合类型(例如,tuple、list、dict),并递归调用
create_args处理元素。给定一个 Proxy 对象,返回对底层 IR
Node的引用。给定一个非 Proxy Tensor 对象,发出各种情况的 IR。
对于 Parameter,发出指向该 Parameter 的
get_attr节点。对于非 Parameter Tensor,将该 Tensor 存储在一个特殊属性中,该属性指向该属性。
此方法可以被重写以支持更多类型。
- 参数:
a (Any) – 将被发出为
Graph中Argument的值。- 返回:
已转换为适当
Argument的值a- 返回类型:
Argument
注意
此 API 的向后兼容性已得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]#
创建与
root模块的签名对应的placeholder节点。此方法会内省 root 的签名并据此发出这些节点,同时也支持*args和**kwargs。警告
此 API 是实验性的,并且不向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]#
给定目标、参数、关键字参数和名称,插入一个图节点。
此方法可以被重写,以进行额外的检查、验证或修改用于节点创建的值。例如,您可能希望禁止记录就地操作。
注意
此 API 的向后兼容性已得到保证。
- 返回类型:
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]#
根据给定的参数创建一个节点,然后返回包装在 Proxy 对象中的节点。
如果 kind = ‘placeholder’,那么我们正在创建一个代表函数参数的节点。如果需要编码默认参数,我们使用
args元组。对于placeholder节点,args否则为空。注意
此 API 的向后兼容性已得到保证。
- getattr(attr, attr_val, parameter_proxy_cache)[source]#
当调用
nn.Module实例的调用时,指定此Tracer行为的方法。默认情况下,行为是为属性返回一个代理值。它还将代理值存储在
parameter_proxy_cache中,以便将来的调用将重用代理而不是创建新的代理。此方法可以被重写,例如,在查询参数时返回代理。
- 参数:
attr (str) – 被查询属性的名称
attr_val (Any) – 属性的值
parameter_proxy_cache (Dict[str, Any]) – 属性名到代理的缓存
- 返回:
getattr 调用返回的值。
警告
此 API 是实验性的,并且不向后兼容。
- is_leaf_module(m, module_qualified_name)[source]#
指定给定
nn.Module是否为“叶”模块的方法。叶模块是在 IR 中出现的原子单元,由
call_module调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块是叶模块。所有其他模块都将被跟踪并通过,其组成的操作将被记录,除非通过此参数另有指定。- 参数:
m (Module) – 被查询的模块
module_qualified_name (str) – 到此模块根的路径。例如,如果您有一个模块层次结构,其中子模块
foo包含子模块bar,它包含子模块baz,该模块将在此处显示为合格名称foo.bar.baz。
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- iter(obj)[source]#
- 当代理对象被迭代时调用,例如
在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义跟踪器可以通过 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。
注意
此 API 的向后兼容性已得到保证。
- 返回类型:
- keys(obj)[source]#
- 当代理对象调用 keys() 方法时调用。
当在代理上调用 ** 时会发生这种情况。如果 ** 旨在为您的自定义跟踪器工作,则应返回一个迭代器。
注意
此 API 的向后兼容性已得到保证。
- 返回类型:
- path_of_module(mod)[source]#
用于查找
mod在root的模块层次结构中的合格名称的帮助方法。例如,如果root有一个名为foo的子模块,它有一个名为bar的子模块,将bar传递给此函数将返回字符串“foo.bar”。- 参数:
mod (str) – 要为其检索合格名称的
Module。- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- to_bool(obj)[source]#
- 当代理对象被转换为布尔值时调用,例如
在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义跟踪器可以通过 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
此 API 的向后兼容性已得到保证。
- 返回类型:
- trace(root, concrete_args=None)[source]#
跟踪
root并返回相应的 FXGraph表示。root可以是nn.Module实例或 Python 可调用对象。请注意,在此调用之后,
self.root可能与此处传递的root不同。例如,当一个自由函数传递给trace()时,我们将创建一个nn.Module实例作为根并添加嵌入的常量。- 参数:
root (Union[Module, Callable]) – 要跟踪的
Module或函数。此参数的向后兼容性已保证。concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。此参数是实验性的,其向后兼容性不保证。
- 返回:
表示传入的
root语义的Graph。- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Proxy(node, tracer=None)[source]#
Proxy对象是Node包装器,它们在符号跟踪期间流经程序,并将它们触及的所有操作(torch函数调用、方法调用、运算符)记录到不断增长的 FX 图中。如果您正在进行图转换,您可以将自己的
Proxy方法包装到原始Node上,以便可以使用重载的运算符将其他内容添加到Graph中。Proxy对象不能被迭代。换句话说,符号跟踪器将在循环中使用Proxy或将其用作*args/**kwargs函数参数时抛出错误。有两种主要方法可以解决此问题:1. 将不可跟踪的逻辑提取到顶级函数中,并对其使用
fx.wrap。2. 如果控制流是静态的(即循环的跳数基于某个超参数),则代码可以保留在原始位置并重构为如下内容:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关 Proxy 内部机制的更详细描述,请参阅 torch/fx/README.md 中的“Proxy”部分。
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source]#
Interpreter 按节点逐个执行 FX 图。此模式可用于许多用途,包括编写代码转换和分析传递。
Interpreter 类中的方法可以被重写以自定义执行行为。可重写方法的调用层次结构图
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想将所有
torch.neg实例与torch.sigmoid相互替换(包括它们的Tensor方法等价物)。我们可以像这样子类化 Interpreter:class NegSigmSwapInterpreter(Interpreter): def call_function( self, target: Target, args: Tuple, kwargs: Dict ) -> Any: if target is torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数:
module (torch.nn.Module) – 要执行的模块
garbage_collect_values (bool) – 是否在模块执行期间的最后使用后删除值。这可确保执行期间的最佳内存使用。可以禁用此选项,例如,通过查看
Interpreter.env属性来检查所有中间值。graph (Optional[Graph]) – 如果传入,解释器将执行此图而不是 module.graph,并使用提供的 module 参数来满足任何状态请求。
注意
此 API 的向后兼容性已得到保证。
- boxed_run(args_list)[source]#
通过解释运行 module 并返回结果。这使用了“boxed”调用约定,其中您传递一个参数列表,该列表将被解释器清除。这确保输入张量能被及时取消分配。
注意
此 API 的向后兼容性已得到保证。
- call_function(target, args, kwargs)[source]#
执行
call_function节点并返回结果。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型:
- Return
Any: 函数调用返回的值
注意
此 API 的向后兼容性已得到保证。
- call_method(target, args, kwargs)[source]#
执行
call_method节点并返回结果。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型:
- Return
Any: 方法调用返回的值
注意
此 API 的向后兼容性已得到保证。
- call_module(target, args, kwargs)[source]#
执行
call_module节点并返回结果。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型:
- Return
Any: 模块调用返回的值
注意
此 API 的向后兼容性已得到保证。
- fetch_args_kwargs_from_env(n)[source]#
从当前执行环境中获取节点
n的args和kwargs的具体值。- 参数:
n (Node) – 应从中获取
args和kwargs的节点。- 返回:
n的带有具体值的args和kwargs。- 返回类型:
Tuple[Tuple, Dict]
注意
此 API 的向后兼容性已得到保证。
- fetch_attr(target)[source]#
从
self.module的模块层次结构中获取一个属性。- 参数:
target (str) – 要获取的属性的完全限定名称
- 返回:
属性的值。
- 返回类型:
任何
注意
此 API 的向后兼容性已得到保证。
- get_attr(target, args, kwargs)[source]#
执行
get_attr节点。将从self.module的模块层次结构中检索属性值。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回:
检索到的属性值
- 返回类型:
任何
注意
此 API 的向后兼容性已得到保证。
- map_nodes_to_values(args, n)[source]#
递归地遍历
args中的数据结构,并在当前执行环境中查找每个Node的具体值。- 参数:
args (Argument) – 用于查找具体值的结构
n (Node) –
args所属的节点。仅用于错误报告。
- 返回类型:
tuple[Argument, …] | Sequence[Argument] | Mapping[str, Argument] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None
注意
此 API 的向后兼容性已得到保证。
- output(target, args, kwargs)[source]#
执行
output节点。这实际上只是检索由output节点引用的值并返回它。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回:
输出节点引用的返回值
- 返回类型:
任何
注意
此 API 的向后兼容性已得到保证。
- placeholder(target, args, kwargs)[source]#
执行
placeholder节点。注意这是有状态的:Interpreter维护一个传递给run的参数的内部迭代器,此方法在此迭代器上调用 next()。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回:
检索到的参数值。
- 返回类型:
任何
注意
此 API 的向后兼容性已得到保证。
- run(*args, initial_env=None, enable_io_processing=True)[source]#
通过解释运行 module 并返回结果。
- 参数:
*args – 要运行的模块的参数,按位置顺序排列
initial_env (Optional[Dict[Node, Any]]) – 一个可选的执行起始环境。这是一个映射 Node 到任何值的字典。例如,这可用于预填充某些 Nodes 的结果,以便仅在解释器中进行部分求值。
enable_io_processing (bool) – 如果为 true,我们将在使用它们之前先使用 graph 的 process_inputs 和 process_outputs 函数来处理输入和输出。
- 返回:
执行模块返回的值
- 返回类型:
任何
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Transformer(module)[source]#
Transformer是一种特殊的解释器,它生成一个新的Module。它公开了一个transform()方法,该方法返回转换后的Module。Transformer不需要参数即可运行,而Interpreter则需要。Transformer完全以符号方式工作。示例
假设我们想将所有
torch.neg实例与torch.sigmoid相互替换(包括它们的Tensor方法等价物)。我们可以像这样子类化Transformer:class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target is torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数:
module (GraphModule) – 要转换的
Module。
注意
此 API 的向后兼容性已得到保证。
- get_attr(target, args, kwargs)[source]#
执行
get_attr节点。在Transformer中,这被重写以将新的get_attr节点插入到输出图中。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- placeholder(target, args, kwargs)[source]#
执行
placeholder节点。在Transformer中,这被重写以将新的placeholder插入到输出图中。- 参数:
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型:
注意
此 API 的向后兼容性已得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[source]#
匹配 GraphModule (
gm) 的 Graph 中的所有可能的非重叠的运算符及其数据依赖项 (pattern) 的集合,然后将每个匹配的子图替换为另一个子图 (replacement)。- 参数:
gm (GraphModule) – 包装要操作的 Graph 的 GraphModule
pattern (Callable | GraphModule) – 要在
gm中匹配以进行替换的子图replacement (Callable | GraphModule) – 用于替换
pattern的子图
- 返回:
一个
Match对象列表,表示原始图中pattern匹配到的位置。如果没有任何匹配,列表将为空。Match定义为:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型:
List[Match]
示例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]) def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的代码将首先在
traced_module的forward方法中匹配pattern。模式匹配是基于使用-定义关系进行的,而不是节点名称。例如,如果您的pattern中有p = torch.cat([a, b]),您可以在原始的forward函数中匹配m = torch.cat([a, b]),即使变量名不同(pvsm)。pattern中的return语句仅根据其值进行匹配;它可能匹配也可能不匹配较大图中的return语句。换句话说,模式不必延伸到较大图的末尾。当模式匹配成功时,它将被从较大函数中移除,并被
replacement替换。如果较大函数中有多个pattern的匹配项,则每个不重叠的匹配项都将被替换。如果存在匹配项重叠,则将替换重叠匹配项集合中的第一个找到的匹配项。(这里的“第一个”定义为在节点的使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的任何内容。)需要注意的一个重要事项是,
pattern可调用对象(Callable)的参数必须在可调用对象本身中使用,并且replacement可调用对象的参数必须与模式匹配。第一个规则解释了为什么在上面的代码块中,forward函数具有参数x, w1, w2,而pattern函数仅具有参数w1, w2。pattern没有使用x,因此它不应该将x指定为参数。作为第二个规则的示例,考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
替换
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement需要与pattern相同数量的参数(x和y),即使参数y在replacement中未被使用。调用
subgraph_rewriter.replace_pattern后,生成的 Python 代码如下所示def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向后兼容性已得到保证。
- torch.fx.traceback.annotate(annotation_dict)[source]#
临时为当前跟踪上下文添加自定义注释。从此跟踪上下文生成的 fx_node 将在 node.metadata[“custom”] 字段中包含自定义注释。
此上下文管理器允许您通过更新全局 current_meta[“custom”] 字典,将任意元数据插入 PT2 跟踪系统。注释在上下文退出后会自动恢复。
梯度累积节点不会被注释。
此功能面向需要在使用导出跟踪期间将额外元数据附加到 fx 节点(例如,用于调试、分析或外部工具)的高级用户。
注意
此 API **不向后兼容**,并且可能在未来的版本中发生变化。
注意
此 API 不兼容 fx.symbolic_trace 或 jit.trace。它旨在与 PT2 系列的跟踪器一起使用,例如 torch.export 和 dynamo。
- 参数:
annotation_dict (dict) – 要注入到 FX 跟踪元数据中的自定义键值对字典。
示例
退出上下文后,自定义注释将被移除。
>>> with annotate({"source": "custom_pass", "tag": 42}): ... pass # Your computation here
警告
此 API 是实验性的,并且不向后兼容。