torch.fx#
创建于: 2020年12月15日 | 最后更新于: 2025年6月12日
概述#
FX 是一个用于开发人员转换 `nn.Module` 实例的工具包。FX 包含三个主要组件:符号跟踪器、中间表示和Python 代码生成。这些组件的演示
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号跟踪器 对 Python 代码执行“符号执行”。它将假的(称为代理 Proxy)值输入代码。对这些代理的操作会被记录下来。有关符号跟踪的更多信息可以在 symbolic_trace()
和 Tracer
文档中找到。
中间表示 是在符号跟踪期间记录的操作的容器。它由一系列节点组成,这些节点表示函数输入、调用点(指向函数、方法或 `torch.nn.Module` 实例)和返回值。有关 IR 的更多信息可以在 Graph
的文档中找到。IR 是应用转换的格式。
Python 代码生成 是 FX 成为 Python 到 Python(或模块到模块)转换工具包的原因。对于每个 Graph IR,我们可以创建与 Graph 语义匹配的有效 Python 代码。此功能封装在 GraphModule
中,它是一个 `torch.nn.Module` 实例,包含一个 Graph
以及一个从 Graph 生成的 `forward` 方法。
总而言之,这个组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件也可以分开使用。例如,符号跟踪可以单独用于捕获代码的某种形式以进行分析(而非转换)。代码生成可用于以编程方式生成模型,例如从配置文件生成。FX 有许多用途!
可以在 examples 仓库中找到几个示例转换。
编写转换#
什么是 FX 转换?本质上,它是一个如下所示的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换将接收一个 `torch.nn.Module`,从中获取一个 `Graph`,进行一些修改,然后返回一个新的 `torch.nn.Module`。您应该将您的 FX 转换返回的 `torch.nn.Module` 视为与常规 `torch.nn.Module` 相同——您可以将其传递给另一个 FX 转换,传递给 TorchScript,或者直接运行它。确保您的 FX 转换的输入和输出都是 `torch.nn.Module` 将允许组合。
注意
也可以修改现有的 `GraphModule` 而不是创建一个新的,如下所示
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,您必须调用 `GraphModule.recompile()` 将 `GraphModule` 上生成的 `forward()` 方法与修改后的 `Graph` 保持同步。
鉴于您传递了一个已被跟踪为 `Graph` 的 `torch.nn.Module`,现在您可以通过两种主要方法来构建一个新的 `Graph`。
Graph 快速入门#
有关图语义的完整介绍可以在 Graph
文档中找到,但我们在此处介绍基础知识。`Graph` 是一个表示 `GraphModule` 中方法的数据结构。它所需的信息是
方法有哪些输入?
方法中运行了哪些操作?
方法的输出(即返回值)是什么?
所有这三个概念都用 `Node` 实例表示。让我们用一个简短的例子来看看我们的意思
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
这里我们为了演示定义了一个模块 `MyModule`,实例化它,对其进行符号跟踪,然后调用 `Graph.print_tabular()` 方法来打印一个表格,显示此 `Graph` 的节点
opcode |
name |
target |
args |
kwargs |
---|---|---|---|---|
placeholder |
x |
x |
() |
{} |
get_attr |
linear_weight |
linear.weight |
() |
{} |
call_function |
add_1 |
(x, linear_weight) |
{} |
|
call_module |
linear_1 |
linear |
(add_1,) |
{} |
call_method |
relu_1 |
relu |
(linear_1,) |
{} |
call_function |
sum_1 |
<built-in method sum …> |
(relu_1,) |
{‘dim’: -1} |
call_function |
topk_1 |
<built-in method topk …> |
(sum_1, 3) |
{} |
output |
output |
output |
(topk_1,) |
{} |
我们可以利用这些信息来回答我们之前提出的问题。
方法有哪些输入?在 FX 中,方法的输入通过特殊的 `placeholder` 节点指定。在本例中,我们有一个名为 x 的单一 `placeholder` 节点,其 `target` 为 `x`,这意味着我们有一个单独的(非自生的)参数 x。
方法中有哪些操作? `get_attr`、`call_function`、`call_module` 和 `call_method` 节点表示方法中的操作。所有这些操作语义的完整处理可以在
Node
文档中找到。方法的返回值是什么? `Graph` 中的返回值由特殊的 `output` 节点指定。
既然我们现在知道了代码在 FX 中如何表示的基础知识,我们就可以开始探索如何编辑 `Graph` 了。
Graph 操作#
直接 Graph 操作#
一种构建新Graph
的方法是直接操作旧图。为了方便起见,我们可以直接获取通过符号跟踪得到的Graph
并对其进行修改。例如,假设我们希望将 torch.add()
调用替换为 torch.mul()
调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们也可以进行更复杂的Graph
重写,例如删除或追加节点。为了帮助进行这些转换,FX提供了用于转换图的实用函数,这些函数可以在Graph
文档中找到。下面是一个使用这些 API 追加 torch.relu()
调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,您还可以利用子图重写器。
使用 replace_pattern() 进行子图重写#
FX 还提供了一个更高级别的直接图操作自动化。 replace_pattern()
API 本质上是一个用于编辑Graph
的“查找/替换”工具。它允许您指定一个 pattern
和一个 replacement
函数,它将跟踪这些函数,在 pattern
图中找到操作组的实例,并将这些实例替换为 replacement
图的副本。这有助于大大自动化繁琐的图操作代码,因为随着转换变得越来越复杂,这些代码可能会变得难以处理。
Proxy/重跟踪#
操作Graph
的另一种方法是重用符号跟踪中使用的 Proxy
机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x)
调用转换为 (x > 0) * x
。一种可能性是执行必要的图重写,在 F.relu
之后插入比较和乘法,然后清理原始的 F.relu
。但是,我们可以通过使用 Proxy
对象来自动捕获操作并将它们附加到 Graph
中来自动化此过程。
要使用此方法,我们将想要插入的操作编写为常规 PyTorch 代码,并使用 Proxy
对象作为参数来调用该代码。这些 Proxy
对象将捕获在它们上执行的操作,并将它们附加到 Graph
。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图操作之外,使用 Proxy
还允许您将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy
时,我们还传递了一个指向底层变量 graph
的跟踪器。这样做是为了以防图中的操作是 n 元的(例如,add 是一个二元运算符),对 Proxy
的调用不会创建图跟踪器的多个实例,这可能会导致意外的运行时错误。我们推荐使用 Proxy
的这种方法,特别是当底层运算符不能安全地假定为一元运算符时。
解释器模式#
FX 中一个有用的代码组织模式是遍历 Graph
中的所有 Node
并执行它们。这可以用于多种目的,包括对图中值流动的运行时分析,或通过使用 Proxy
进行重跟踪来转换代码。例如,假设我们想运行一个 GraphModule
,并在运行时观察节点上的 torch.Tensor
的形状和 dtype 属性。这可能看起来像这样
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
正如你所见,FX 的完整解释器并不复杂,但它可能非常有用。为了方便使用这种模式,我们提供了 Interpreter
类,它以一种可以通过方法重写来覆盖解释器执行的某些方面的方式来包含上述逻辑。
除了执行操作之外,我们还可以通过将 Proxy
值馈送到解释器中来生成新的 Graph
。类似地,我们提供了 Transformer
类来包含这种模式。 Transformer
的行为类似于 Interpreter
,但不是调用 run
方法来获取模块的具体输出值,而是调用 Transformer.transform()
方法来返回一个新的 GraphModule
,该模块已根据您作为覆盖方法安装的任何转换规则进行了修改。
调试#
引言#
在编写转换过程中,我们的代码可能并不总是正确的。在这种情况下,我们可能需要进行一些调试。关键是向后工作:首先,检查生成的模块调用的结果,以证明或证伪其正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用的调试器。
转换编写中的常见陷阱#
set
的迭代顺序不确定。在 Python 中,set
数据类型是无序的。使用set
来包含Node
等对象的集合,可能会导致意外的不确定性。一个例子是迭代一组Node
以将它们插入Graph
。由于set
数据类型是无序的,输出程序中操作的顺序将是不确定的,并且可以在程序调用之间发生变化。推荐的替代方法是使用dict
数据类型,该数据类型在 Python 3.7(以及 cpython 3.6)中是按插入顺序排序的。dict
可以等效地用作集合,方法是将要进行去重的值存储在dict
的键中。
检查模块的正确性#
由于大多数深度学习模块的输出由浮点 torch.Tensor
实例组成,因此检查两个 torch.nn.Module
的结果是否等同并不像执行简单的相等性检查那样直接。为了说明这一点,我们举一个例子
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用 ==
相等运算符来检查两个深度学习模型的数值是否相等。然而,这并不明确。
该运算符返回一个张量而不是一个布尔值,这使得它不明确,并且因为浮点值的比较应该使用误差范围(或 epsilon)来考虑浮点运算的非交换性(有关更多详细信息,请参阅此处)。我们可以改用 torch.allclose()
,它将根据相对和绝对容差阈值给出近似比较
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们用来检查转换后的模块是否按照我们期望的方式与参考实现进行行为的第一个工具。
调试生成的代码#
由于 FX 会在 GraphModule
上生成 forward()
函数,因此使用 print
语句或 pdb
等传统调试技术并不像直接使用它们那样简单。幸运的是,我们有几种技术可以用于调试生成的代码。
使用 pdb
#
调用 pdb
来进入正在运行的程序。虽然代表 Graph
的代码不在任何源文件中,但当调用前向传递时,我们仍然可以使用 pdb
手动进入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码#
如果您想多次运行相同的代码,那么使用 pdb
跳转到正确的代码可能会有些乏味。在这种情况下,一种方法是直接将生成的 forward
传递复制到您的代码中,并从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
的 to_folder
函数#
GraphModule.to_folder()
是 GraphModule
中的一个方法,允许您将生成的 FX 代码转储到一个文件夹中。尽管将 forward 传递复制到代码中通常就足够了(如 打印生成的代码 中所述),但使用 to_folder
检查模块和参数可能更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上面的示例后,我们可以查看 foo/module.py
中的代码,并根据需要对其进行修改(例如,添加 print
语句或使用 pdb
)来调试生成的代码。
调试转换#
现在我们已经确定转换产生了错误的代码,是时候调试转换本身了。首先,我们将检查文档中的 符号跟踪的限制 部分。一旦我们验证了跟踪正在按预期工作,目标就是找出在我们的 GraphModule
转换过程中出了什么问题。 编写转换 中可能有一个快速的答案,但如果没有,有几种方法可以检查我们跟踪的模块
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面提供的实用函数,我们可以比较我们应用转换之前和之后的跟踪模块。有时,简单的目视比较足以追溯错误。如果仍然不清楚出了什么问题,像 pdb
这样的调试器可能是一个不错的下一步。
根据上面的例子,考虑以下代码
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的示例,假设 print(traced)
的调用显示我们的转换中存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb
会话。通过在 transform_graph(traced)
上设置断点,然后按 s
“进入” transform_graph(traced)
的调用,我们可以看到转换过程中发生了什么。
通过编辑 print_tabular
方法来打印图中节点的不同属性,我们可能会有所收获。(例如,我们可能想查看节点的 input_nodes
和 users
。)
可用的调试器#
最常见的 Python 调试器是 pdb。您可以通过在命令行中键入 python -m pdb FILENAME.py
来以“调试模式”启动程序,其中 FILENAME
是您要调试的文件名。之后,您可以使用 pdb
调试器命令分步执行正在运行的程序。通常在启动 pdb
时设置一个断点(b LINE-NUMBER
),然后调用 c
来运行程序直到该点。这可以防止您必须逐行执行(使用 s
或 n
)才能到达您想检查的代码部分。或者,您可以在要中断的行之前编写 import pdb; pdb.set_trace()
。如果您添加了 pdb.set_trace()
,那么当您运行程序时,它将自动进入调试模式。(换句话说,您只需在命令行中键入 python FILENAME.py
而不是 python -m pdb FILENAME.py
。)一旦您以调试模式运行文件,您就可以使用某些命令逐步执行代码并检查程序的内部状态。网上有许多优秀的关于 pdb
的教程,包括 RealPython 的 “使用 Pdb 进行 Python 调试”。
PyCharm 或 VSCode 等 IDE 通常内置调试器。在您的 IDE 中,您可以选择 a) 通过在 IDE 中打开终端窗口(例如,在 VSCode 中选择 View → Terminal)来使用 pdb
,或者 b) 使用内置调试器(通常是 pdb
的图形化包装器)。
符号跟踪的限制#
FX 使用**符号跟踪**(也称为符号执行)系统,将程序语义捕获为可转换/可分析的形式。该系统是**跟踪**的,因为它会执行程序(实际上是 torch.nn.Module
或函数)来记录操作。它是**符号**的,因为在此执行期间流经程序的数据不是真实数据,而是符号(在 FX 术语中是 Proxy
)。
尽管符号跟踪适用于大多数神经网络代码,但它也有一些限制。
动态控制流#
符号跟踪的主要限制是它目前不支持*动态控制流*。也就是说,循环或 if
语句,其条件可能取决于程序的输入值。
例如,我们来看下面的程序
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
语句的条件依赖于 x.sum()
的值,而 x.sum()
又依赖于 x
的值,而 x
是一个函数输入。由于 x
可以改变(例如,如果您将新的输入张量传递给被跟踪的函数),这就是*动态控制流*。堆栈跟踪会向上追溯到您的代码,向您显示这种情况发生的位置。
静态控制流#
另一方面,所谓的*静态控制流*是支持的。静态控制流是循环或 if
语句,其值在调用之间不会改变。通常,在 PyTorch 程序中,这种控制流是由于代码根据超参数来决定模型的架构而产生的。作为具体示例
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if self.do_activation
的 if 语句不依赖于任何函数输入,因此它是静态的。 do_activation
可以被视为一个超参数,并且不同 MyModule
实例(具有该参数的不同值)的跟踪具有不同的代码。这是一个有效的模式,得到了符号跟踪的支持。
许多动态控制流的实例是语义上的静态控制流。这些实例可以通过移除对输入值的依赖来支持符号跟踪,例如,将值移动到 Module
属性,或者在符号跟踪期间将具体值绑定到参数
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含这些代码的程序部分可以被跟踪为对 Method(请参阅 使用 Tracer 类自定义跟踪)或 Function(请参阅 wrap()
)的调用,而不是跟踪它们。
非 torch
函数#
FX 使用 __torch_function__
作为它拦截调用的机制(有关更多信息,请参阅技术概述)。一些函数,例如内置的 Python 函数或 math
模块中的函数,不包含在 __torch_function__
中,但我们仍然希望在符号跟踪中捕获它们。例如
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误告诉我们内置函数 len
不受支持。我们可以通过使用 wrap()
API 来使类似这样的函数在跟踪中记录为直接调用。
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer
类自定义跟踪#
Tracer
类是 symbolic_trace
实现的基础类。可以通过继承 Tracer 来自定义跟踪行为,如下所示
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶子模块#
叶子模块是在符号跟踪中作为调用出现,而不是被跟踪的模块。默认的叶子模块集是标准的 torch.nn
模块实例集。例如
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
叶子模块集可以通过重写 Tracer.is_leaf_module()
来自定义。
杂项#
张量构造函数(例如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
)目前是不可跟踪的。可以安全地使用确定性构造函数(
zeros
、ones
),并且它们生成的值将被嵌入到跟踪中作为常量。这只有在这些构造函数的参数引用动态输入大小时才成问题。在这种情况下,ones_like
或zeros_like
可能是可行的替代方案。不确定性构造函数(
rand
、randn
)将在跟踪中嵌入一个随机值。这可能不是预期行为。一种解决方法是将torch.randn
包装在torch.fx.wrap
函数中,然后改用它。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能在未来版本中得到修复。
类型注解
Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor
)是支持的,并且将被符号跟踪保留。不支持 Python 2 风格的类型注解
# type: (torch.Tensor, int) -> torch.Tensor
。目前不支持函数内局部名称的注解。
关于
training
标志和子模块的注意事项在使用像
torch.nn.functional.dropout
这样的函数时,通常会将 training 参数传递为self.training
。在 FX 跟踪期间,这很可能会被烘焙为常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
然而,当使用标准的
nn.Dropout()
子模块时,training 标志会被封装起来,并且由于nn.Module
对象模型的保留,可以被更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
因此,请考虑将动态交互
training
标志的模块标记为叶子模块。
API 参考#
- torch.fx.symbolic_trace(root, concrete_args=None)[source]#
符号跟踪 API
给定一个
nn.Module
或函数实例root
,此函数将返回一个GraphModule
,该模块通过记录在跟踪root
期间看到的运算来构建。concrete_args
允许您部分特化函数,无论是为了移除控制流还是数据结构。例如
def f(a, b): if b == True: return a else: return a * 2
FX 通常无法跟踪此项,因为其中包含控制流。但是,我们可以使用 concrete_args 来特化 b 的值以进行跟踪。
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
请注意,虽然您仍然可以传入 b 的不同值,但它们将被忽略。
我们还可以使用 concrete_args 来消除函数中的数据结构处理。这会使用 pytrees 来展平您的输入。为了避免过度特化,请为不应特化的值传入 fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace( f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} ) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 参数
root (Union[torch.nn.Module, Callable]) – 要被跟踪并转换为 Graph 表示的模块或函数。
concrete_args (Optional[Dict[str, any]]) – 要部分特化的输入。
- 返回
一个从
root
记录的操作创建的模块。- 返回类型
注意
此 API 的向后兼容性已得到保证。
- torch.fx.wrap(fn_or_name)[source]#
此函数可以在模块级别范围调用,以将 fn_or_name 注册为“叶子函数”。“叶子函数”将在 FX 跟踪中保留为 CallFunction 节点,而不是被跟踪。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以等效地用作装饰器。
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
包装后的函数可以被视为“叶子函数”,类似于“叶子模块”的概念,即它们是在 FX 跟踪中保留为调用的函数,而不是被跟踪。
- 参数
fn_or_name (Union[str, Callable]) – 调用时要插入图中的函数或全局函数的名称。
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.GraphModule(*args, **kwargs)[source]#
GraphModule 是一个由 fx.Graph 生成的 nn.Module。Graphmodule 具有
graph
属性,以及从该graph
生成的code
和forward
属性。警告
当
graph
被重新赋值时,code
和forward
将被自动重新生成。但是,如果您在不重新赋值graph
属性本身的情况下编辑graph
的内容,则必须调用recompile()
来更新生成的代码。注意
此 API 的向后兼容性已得到保证。
- __init__(root, graph, class_name='GraphModule')[source]#
构造一个 GraphModule。
- 参数
root (Union[torch.nn.Module, Dict[str, Any]) –
root
可以是 nn.Module 实例,也可以是映射字符串到任何属性类型的 Dict。如果root
是一个 Module,则 Graph 的 Nodes 的target
字段中对基于 Module 的对象的任何引用(通过限定名)将从root
的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果root
是一个 dict,则将在 dict 的键中直接查找 Node 的target
中找到的限定名。由 Dict 映射的对象将复制到 GraphModule 的模块层次结构中的适当位置。graph (Graph) –
graph
包含此 GraphModule 应使用以生成代码的节点。class_name (str) –
name
表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自GraphModule
。将此设置为root
的原始名称或在您的转换上下文中具有意义的名称可能很有帮助。
注意
此 API 的向后兼容性已得到保证。
- add_submodule(target, m)[source]#
将给定的子模块添加到
self
。这会在目标路径中尚未存在时安装空的 Modules,如果它们是
target
的子路径。- 参数
- 返回
- 子模块是否可以被插入。要
使此方法返回 True,由
target
表示的链中的每个对象必须满足以下条件之一:a) 尚不存在,或 b) 引用一个nn.Module
(不是参数或其他属性)。
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- delete_all_unused_submodules()[source]#
从
self
中删除所有未使用的子模块。如果以下任一条件成立,则认为一个 Module “被使用”:1. 它具有被使用的子模块。2. 它的 forward 通过
call_module
节点直接调用。3. 它具有一个非 Module 属性,该属性从get_attr
节点被使用。可以通过调用此方法来清理
nn.Module
,而无需手动对每个未使用的子模块调用delete_submodule
。注意
此 API 的向后兼容性已得到保证。
- delete_submodule(target)[source]#
从
self
中删除给定的子模块。如果
target
不是有效目标,则该模块不会被删除。- 参数
target (str) – 新子模块的完全限定字符串名称(请参阅
nn.Module.get_submodule
中的示例,了解如何指定完全限定字符串)。- 返回
- target 字符串是否引用了
我们想要删除的子模块。返回值为
False
意味着target
不是对子模块的有效引用。
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False, *, fast_sympy_print=False)[source]#
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码。
警告
此 API 是实验性的,**不**向后兼容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
Graph
是 FX 中间表示的主要数据结构。它由一系列Node
组成,每个Node
代表一个调用点(或其他语法结构)。Node
的列表共同构成了一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
将产生以下 Graph:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1
有关
Graph
中表示的操作的语义,请参阅Node
。注意
此 API 的向后兼容性已得到保证。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
构造一个空的 Graph。
注意
此 API 的向后兼容性已得到保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None, name=None)[source]#
将
call_function
Node
插入到Graph
中。call_function
节点表示对由the_function
指定的 Python 可调用对象进行调用。- 参数
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或
builtins
或operator
命名空间中的成员。args (Optional[Tuple[Argument, ...]]) – 要传递给所调用函数的 positional arguments。
kwargs (Optional[Dict[str, Argument]]) – 要传递给所调用函数的 keyword arguments。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
name (Optional[str]) – 节点的名称。如果未指定,则设置为 None。
- 返回
新创建并插入的
call_function
节点。- 返回类型
注意
此方法与
Graph.create_node()
适用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source]#
将
call_method
Node
插入到Graph
中。call_method
节点表示对args
的第 0 个元素上的给定方法进行调用。- 参数
method_name (str) – 应用于 self 参数的方法名称。例如,如果 args[0] 是代表
Tensor
的Node
,则要在该Tensor
上调用relu()
,则将relu
传递给method_name
。args (Optional[Tuple[Argument, ...]]) – 要传递给所调用方法的 positional arguments。请注意,这 *应该* 包含一个
self
参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给所调用方法的 keyword arguments。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回
新创建并插入的
call_method
节点。- 返回类型
注意
此方法与
Graph.create_node()
适用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source]#
将
call_module
Node
插入到Graph
中。call_module
节点表示对Module
层次结构中的Module
的 forward() 函数的调用。- 参数
module_name (str) – 要调用的
Module
层次结构中Module
的限定名称。例如,如果被跟踪的Module
具有一个名为foo
的子模块,该子模块又有一个名为bar
的子模块,则应将限定名foo.bar
作为module_name
来调用该模块。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的 positional arguments。请注意,这不应包含
self
参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给所调用方法的 keyword arguments。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回
新创建并插入的
call_module
节点。- 返回类型
注意
此方法与
Graph.create_node()
适用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性已得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]#
创建一个
Node
并将其添加到当前插入点的Graph
中。请注意,当前插入点可以通过Graph.inserting_before()
和Graph.inserting_after()
进行设置。- 参数
op (str) – 此 Node 的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在
Graph
文档字符串中进行了描述。args (Optional[Tuple[Argument, ...]]) – 此节点的参数元组。
kwargs (Optional[Dict[str, Argument]]) – 此 Node 的 kwargs。
name (Optional[str]) – 节点的可选字符串名称。这将影响生成的 Python 代码中赋值值的名称。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回
新创建并插入的节点。
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- eliminate_dead_code(is_impure_node=None)[source]#
根据每个节点的用户数以及节点是否具有任何副作用,从图中移除所有死代码。调用此方法前,图必须已拓扑排序。
- 参数
- 返回
图是否因该传递而发生更改。
- 返回类型
示例
在消除死代码之前,下面的 a = x + 1 中的 a 没有用户,因此可以从图中消除,而不会产生任何影响。
def forward(self, x): a = x + 1 return x + self.attr_1
消除死代码后,a = x + 1 已被移除,其余 forward 保持不变。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法可以避免移除具有副作用的节点(参见 Node.is_impure),但总的来说覆盖率非常差,因此您应该假定此方法不可靠,除非您知道您的 FX 图完全由函数式运算组成,或者您提供了自己的自定义函数来检测具有副作用的节点。
注意
此 API 的向后兼容性已得到保证。
- erase_node(to_erase)[source]#
从
Graph
中擦除一个Node
。如果该节点在Graph
中仍有用户,则会引发异常。- 参数
to_erase (Node) – 要从
Graph
中擦除的Node
。
注意
此 API 的向后兼容性已得到保证。
- find_nodes(*, op, target=None, sort=True)[source]#
允许快速查询节点。
- 参数
- 返回
具有指定操作和 target 的节点的可迭代对象。
警告
此 API 是实验性的,**不**向后兼容。
- get_attr(qualified_name, type_expr=None)[source]#
将 `get_attr` 节点插入到 Graph 中。`get_attr` `Node` 表示从 `Module` 层次结构中获取属性。
- 参数
qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的 Module 有一个名为 `foo` 的子模块,该子模块有一个名为 `bar` 的子模块,该子模块有一个名为 `baz` 的属性,则应将限定名称 `foo.bar.baz` 作为 `qualified_name` 传递。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回
新创建并插入的 `get_attr` 节点。
- 返回类型
注意
此方法的插入点和类型表达式规则与 `Graph.create_node` 相同。
注意
此 API 的向后兼容性已得到保证。
- graph_copy(g, val_map, return_output_node=False)[source]#
将给定图中的所有节点复制到 `self` 中。
- 参数
- 返回
在 `self` 中等效于 `g` 中输出值的值,如果 `g` 有一个 `output` 节点。否则为 `None`。
- 返回类型
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向后兼容性已得到保证。
- inserting_after(n=None)[source]#
- 设置 create_node 和相关方法将插入到图中的位置。
在 `with` 语句中使用时,这将临时设置插入点,然后在 `with` 语句退出时将其恢复。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数 (Args)
- n (Optional[Node]): 在其之前插入的节点。如果为 None,则将在图的开头之后插入。
整个图的开头。
- 返回
一个资源管理器,将在 `__exit__` 时恢复插入点。
注意
此 API 的向后兼容性已得到保证。
- inserting_before(n=None)[source]#
- 设置 create_node 和相关方法将插入到图中的位置。
在 `with` 语句中使用时,这将临时设置插入点,然后在 `with` 语句退出时将其恢复。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数 (Args)
- n (Optional[Node]): 在其之前插入的节点。如果为 None,则将在图的开头之前插入。
整个图的开头。
- 返回
一个资源管理器,将在 `__exit__` 时恢复插入点。
注意
此 API 的向后兼容性已得到保证。
- lint()[source]#
对该 Graph 运行各种检查以确保其格式正确。特别是:- 检查节点是否具有正确的归属(属于此 Graph)- 检查节点是否按拓扑顺序出现- 如果此 Graph 拥有 GraphModule,则检查目标是否存在于该 GraphModule 中。
注意
此 API 的向后兼容性已得到保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source]#
将一个节点从一个图复制到另一个图。`arg_transform` 需要将参数从节点的图转换为 self 的图。示例
# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 参数
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- property nodes: _node_list#
获取构成此 Graph 的节点列表。请注意,此 `Node` 列表表示是双向链表。迭代期间的修改(例如,删除节点,添加节点)是安全的。
节点双向链表。请注意,可以对此列表调用 `reversed` 来切换迭代顺序。
- 返回
节点双向链表。请注意,可以对此列表调用 `reversed` 来切换迭代顺序。
- on_generate_code(make_transformer)[source]#
生成 Python 代码时注册转换器函数
- 参数 (Args)
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
返回要注册的代码转换器的函数。此函数由 `on_generate_code` 调用以获取代码转换器。
此函数还接收当前注册的代码转换器(如果未注册任何内容,则为 None)作为输入,以防不希望覆盖它。这对于链接代码转换器很有用。
- 返回
一个上下文管理器,在 `with` 语句中使用时,可以自动恢复先前注册的代码转换器。
示例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函数还可以用作上下文管理器,其好处是可以自动恢复先前注册的代码转换器。
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 是实验性的,**不**向后兼容。
- output(result, type_expr=None)[source]#
将 `output` `Node` 插入 `Graph` 中。`output` 节点表示 Python 代码中的 `return` 语句。`result` 是应返回的值。
- 参数
result (Argument) – 要返回的值。
type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
注意
此方法的插入点和类型表达式规则与 `Graph.create_node` 相同。
注意
此 API 的向后兼容性已得到保证。
- placeholder(name, type_expr=None, default_value)[source]#
将 `placeholder` 节点插入到 Graph 中。`placeholder` 表示函数输入。
- 参数
name (str) – 输入值的名称。这对应于此 `Graph` 所表示的函数的 positional 参数的名称。
type_expr (Optional[Any]) – 可选的类型注释,表示此节点输出的 Python 类型。在某些情况下(例如,当函数随后用于 TorchScript 编译时)需要此项以进行正确的代码生成。
default_value (Any) – 此函数参数应采用的默认值。注意:为了允许 `None` 作为默认值,应将 `inspect.Signature.empty` 作为此参数传递,以指定该参数 _没有_ 默认值。
- 返回类型
注意
此方法的插入点和类型表达式规则与 `Graph.create_node` 相同。
注意
此 API 的向后兼容性已得到保证。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source]#
将此 `Graph` 转换为有效的 Python 代码。
- 参数
root_module (str) – 用于查找限定名称目标的根模块的名称。通常是 ‘self’。
- 返回
src:表示对象的 Python 源代码 globals:`src` 中的全局名称到它们引用的对象的字典。
- 返回类型
一个 PythonCode 对象,包含两个字段:
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]#
`Node` 是表示 `Graph` 中单个操作的数据结构。在大多数情况下,Nodes 表示对各种实体的调用站点,例如运算符、方法和 Modules(一些例外包括指定函数输入和输出的节点)。每个 `Node` 都有一个由其 `op` 属性指定的函数。每个 `op` 值的 `Node` 语义如下:
`placeholder` 表示函数输入。`name` 属性指定该值将被赋予的名称。`target` 同样是参数的名称。`args` 持有:1) 空,或 2) 一个表示函数输入的默认参数。`kwargs` 是不关心的。Placeholders 对应于图打印输出中的函数参数(例如 `x`)。
`get_attr` 从模块层次结构检索参数。`name` 同样是赋值给获取结果的名称。`target` 是参数在模块层次结构中位置的完全限定名称。`args` 和 `kwargs` 是不关心的。
`call_function` 将一个自由函数应用于某些值。`name` 同样是要分配给的值的名称。`target` 是要应用的函数。`args` 和 `kwargs` 表示函数参数,遵循 Python 调用约定。
`call_module` 将模块层次结构中的模块的 `forward()` 方法应用于给定参数。`name` 如前所述。`target` 是要调用的模块在模块层次结构中的完全限定名称。`args` 和 `kwargs` 表示调用模块的参数, _不包括 self 参数_。
`call_method` 调用值上的方法。`name` 类似。`target` 是应用于 `self` 参数的方法的字符串名称。`args` 和 `kwargs` 表示调用模块的参数, _包括 self 参数_。
`output` 在其 `args[0]` 属性中包含跟踪函数的输出。这对应于 Graph 打印输出中的“return”语句。
注意
此 API 的向后兼容性已得到保证。
- property all_input_nodes: list['Node']#
返回此 Node 的所有输入 Nodes。这等效于遍历 `args` 和 `kwargs` 并仅收集是 Nodes 的值。
- 返回
此 `Node` 的 `args` 和 `kwargs` 中出现的 `Nodes` 的列表,按顺序排列。
- append(x)[source]#
在图的节点列表中在此节点之后插入 `x`。等效于 `self.next.prepend(x)`
- 参数
x (Node) – 要放在此节点之后的节点。必须是同一图的成员。
注意
此 API 的向后兼容性已得到保证。
- property args: tuple[Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...]#
此 `Node` 的参数元组。参数的解释取决于节点的 opcode。有关更多信息,请参阅 `Node` 文档。
允许对该属性进行赋值。所有使用和用户的计数将在赋值时自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source]#
返回 `self` 的描述性字符串表示。
此方法可以不带参数地用作调试工具。
此函数还在 `Graph` 的 `__str__` 方法中用作内部帮助程序。`placeholder_names` 和 `maybe_return_typename` 中的字符串共同构成了此 Graph 的 GraphModule 的 autogenerated `forward` 函数的签名。`placeholder_names` 和 `maybe_return_typename` 不应在其他地方使用。
- 参数
- 返回
- 如果 1) 我们在 `Graph` 的 `__str__` 方法中使用 `format_node` 作为内部帮助程序,并且 2) `self` 是一个 placeholder Node,则返回 `None`。否则,返回当前 Node 的描述性字符串表示。
在 `Graph` 的 `__str__` 方法中使用 `format_node` 作为内部帮助程序,并且 `self` 是一个 placeholder Node,则返回 `None`。否则,返回当前 Node 的描述性字符串表示。
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- insert_arg(idx, arg)[source]#
将 positional 参数插入到具有给定索引的参数列表中。
- 参数
idx (int) – `self.args` 中要插入的元素索引。
arg (Argument) – 要插入到 `args` 中的新参数值。
注意
此 API 的向后兼容性已得到保证。
- is_impure(impure_random=True)[source]#
返回此 op 是否为 impure,即它的 op 是否是 placeholder 或 output,或者是一个 impure 的 call_function 或 call_module。
警告
此 API 是实验性的,**不**向后兼容。
- property kwargs: dict[str, Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]#
此
Node
的关键字参数字典。参数的解释取决于节点的 opcode。有关更多信息,请参阅Node
文档字符串。允许对该属性进行赋值。所有使用和用户的计数将在赋值时自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]#
返回规范化后的 Python 目标参数。这意味着 args/kwargs 将与模块/函数的签名匹配,并且如果 normalize_to_only_use_kwargs 为 true,则仅以位置顺序返回 kwargs。还会填充默认值。不支持仅位置参数或可变参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 来消除重载歧义。
- 参数
root (torch.nn.Module) – 用于解析模块目标的模块。
arg_types (Optional[Tuple[Any]]) – args 的类型元组
kwarg_types (Optional[Dict[str, Any]]) – kwargs 的类型字典
normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用 kwargs。
- 返回
返回 NamedTuple ArgsKwargsPair,如果成功则返回 None。
- 返回类型
Optional[ArgsKwargsPair]
警告
此 API 是实验性的,**不**向后兼容。
- prepend(x)[source]#
在图的节点列表中在此节点之前插入 x。示例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数
x (Node) – 要放在此节点之前的节点。必须是同一图的成员。
注意
此 API 的向后兼容性已得到保证。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source]#
将图中的
self
的所有使用替换为节点replace_with
。- 参数
- 返回
已在此上进行更改的节点列表。
- 返回类型
list[‘Node’]
注意
此 API 的向后兼容性已得到保证。
- replace_input_with(old_input, new_input)[source]#
遍历
self
的输入节点,并将old_input
的所有实例替换为new_input
。注意
此 API 的向后兼容性已得到保证。
- property stack_trace: Optional[str]#
返回跟踪期间记录的 Python 堆栈跟踪(如果存在)。当使用 fx.Tracer 进行跟踪时,此属性通常由 Tracer.create_proxy 填充。要在跟踪期间记录堆栈跟踪以进行调试,请将 Tracer 实例上的 record_stack_traces 设置为 True。当使用 dynamo 进行跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。
stack_trace 的末尾将是最近的帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]#
Tracer
是实现torch.fx.symbolic_trace
符号跟踪功能的类。调用symbolic_trace(m)
等同于Tracer().trace(m)
。Tracer 可以被子类化以覆盖跟踪过程的各种行为。可覆盖的不同行为在此类的文档字符串的方法中进行了描述。
注意
此 API 的向后兼容性已得到保证。
- call_module(m, forward, args, kwargs)[source]#
指定此
Tracer
在遇到对nn.Module
实例的调用时的行为的方法。默认情况下,行为是通过
is_leaf_module
检查被调用模块是否为叶子模块。如果是,则在Graph
中发出一个引用m
的call_module
节点。否则,将正常调用Module
,跟踪其forward
函数中的操作。可以通过覆盖此方法来实现,例如,创建嵌套的跟踪 GraphModules,或在跟踪跨越
Module
边界时实现任何其他行为。- 参数
m (Module) – 正在发出调用的模块
forward (Callable) – 要调用的
Module
的 forward() 方法args (Tuple) – 模块调用点的 args
kwargs (Dict) – 模块调用点的 kwargs
- 返回
模块调用的返回值。如果发出了
call_module
节点,则为Proxy
值。否则,它将是Module
调用返回的任何值。- 返回类型
注意
此 API 的向后兼容性已得到保证。
- create_arg(a)[source]#
用于指定跟踪在准备值作为
Graph
中节点的参数时行为的方法。默认情况下,行为包括
迭代集合类型(例如 tuple、list、dict)并递归调用
create_args
处理元素。给定 Proxy 对象,返回对底层 IR
Node
的引用给定非 Proxy Tensor 对象,发出各种情况的 IR
对于 Parameter,发出一个引用该 Parameter 的
get_attr
节点对于非 Parameter Tensor,将 Tensor 存储在一个特殊的属性中,引用该属性。
可以通过覆盖此方法来支持更多类型。
- 参数
a (Any) – 要作为
Graph
中的Argument
发出的值。- 返回
值
a
已转换为相应的Argument
- 返回类型
Argument
注意
此 API 的向后兼容性已得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]#
创建与
root
模块签名对应的placeholder
节点。此方法内省 root 的签名并相应地发出这些节点,还支持*args
和**kwargs
。警告
此 API 是实验性的,**不**向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]#
给定目标、参数、关键字参数和名称插入一个图节点。
可以覆盖此方法以执行额外的检查、验证或修改用于节点创建的值。例如,有人可能希望不允许记录就地操作。
注意
此 API 的向后兼容性已得到保证。
- 返回类型
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]#
从给定参数创建节点,然后将节点包装在 Proxy 对象中返回。
如果 kind = ‘placeholder’,那么我们正在创建一个表示函数参数的节点。如果我们确实需要编码默认参数,我们将使用
args
元组。对于placeholder
节点,args
否则为空。注意
此 API 的向后兼容性已得到保证。
- getattr(attr, attr_val, parameter_proxy_cache)[source]#
指定此
Tracer
在我们对nn.Module
实例调用 getattr 时的行为的方法。默认情况下,行为是为属性返回一个代理值。它还将代理值存储在
parameter_proxy_cache
中,以便将来的调用可以重用该代理而不是创建新代理。可以通过覆盖此方法来实现,例如,在查询参数时不返回代理。
- 参数
- 返回
getattr 调用的返回值。
警告
此 API 是实验性的,**不**向后兼容。
- is_leaf_module(m, module_qualified_name)[source]#
用于指定给定的
nn.Module
是否为“叶子”模块的方法。叶子模块是 IR 中出现的原子单元,由
call_module
调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块是叶子模块。所有其他模块都会被跟踪并通过,其组成操作会被记录,除非通过此参数另有指定。- 参数
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- iter(obj)[source]#
- 在代理对象被迭代时调用,例如
在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。
注意
此 API 的向后兼容性已得到保证。
- 返回类型
- keys(obj)[source]#
- 在代理对象调用 keys() 方法时调用。
这就是在代理上调用 ** 时发生的情况。如果 ** 在您的自定义跟踪器中工作,它应该返回一个迭代器。
注意
此 API 的向后兼容性已得到保证。
- 返回类型
- path_of_module(mod)[source]#
辅助方法,用于在
root
的模块层级结构中查找mod
的限定名称。例如,如果root
有一个名为foo
的子模块,它有一个名为bar
的子模块,将bar
传递给此函数将返回字符串 “foo.bar”。注意
此 API 的向后兼容性已得到保证。
- to_bool(obj)[source]#
- 在代理对象转换为布尔值时调用,例如
在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
此 API 的向后兼容性已得到保证。
- 返回类型
- class torch.fx.Proxy(node, tracer=None)[source]#
Proxy
对象是符号跟踪过程中使用的Node
包装器,它们会将遇到的所有操作(torch
函数调用、方法调用、运算符)记录到不断增长的 FX Graph 中。如果您正在进行图变换,您可以将自己的
Proxy
方法包装到原始Node
上,以便使用重载运算符将其他内容添加到Graph
中。Proxy
对象不可迭代。换句话说,如果Proxy
在循环或作为*args
/**kwargs
函数参数中使用,符号跟踪器将抛出错误。有两种主要方法可以解决这个问题:1. 将不可跟踪的逻辑分解为顶级函数,并对其使用
fx.wrap
。 2. 如果控制流是静态的(即循环次数基于某些超参数),则可以将代码保留在其原始位置,并重构为类似for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关 Proxy 内部机制的更详细说明,请参阅 torch/fx/README.md 中的“Proxy”部分。
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source]#
Interpreter 按节点逐个执行 FX 图。此模式可用于多种用途,包括编写代码转换以及分析传递。
可以通过覆盖 Interpreter 类中的方法来定制执行行为。方法覆盖的映射(按调用层级)
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想要将所有的
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们对应的Tensor
方法)。我们可以像这样子类化 Interpreter:class NegSigmSwapInterpreter(Interpreter): def call_function( self, target: Target, args: Tuple, kwargs: Dict ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数
module (torch.nn.Module) – 要执行的模块
garbage_collect_values (bool) – 是否在模块执行期间在其最后一次使用后删除值。这可确保执行期间的最佳内存使用。例如,可以通过查看
Interpreter.env
属性来禁用此功能,以便检查执行中的所有中间值。graph (Optional[Graph]) – 如果提供了此参数,则解释器将执行此图而不是 module.graph,并使用提供的 module 参数来满足任何状态请求。
注意
此 API 的向后兼容性已得到保证。
- boxed_run(args_list)[source]#
通过解释来运行 module 并返回结果。这使用了“boxed”调用约定,即您传递一个参数列表,该列表将被解释器清除。这可确保输入张量被及时释放。
注意
此 API 的向后兼容性已得到保证。
- call_function(target, args, kwargs)[source]#
执行
call_function
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型
- Return
Any: 函数调用返回的值
注意
此 API 的向后兼容性已得到保证。
- call_method(target, args, kwargs)[source]#
执行
call_method
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型
- Return
Any: 方法调用返回的值
注意
此 API 的向后兼容性已得到保证。
- call_module(target, args, kwargs)[source]#
执行
call_module
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型
- Return
Any: 模块调用返回的值
注意
此 API 的向后兼容性已得到保证。
- fetch_args_kwargs_from_env(n)[source]#
从当前执行环境中获取节点
n
的具体args
和kwargs
值。- 参数
n (Node) – 需要从中获取
args
和kwargs
的节点。- 返回
args
和kwargs
包含n
的具体值。- 返回类型
Tuple[Tuple, Dict]
注意
此 API 的向后兼容性已得到保证。
- fetch_attr(target)[source]#
从
self.module
的Module
层级结构中获取一个属性。- 参数
target (str) – 要获取的属性的完全限定名
- 返回
属性的值。
- 返回类型
任何
注意
此 API 的向后兼容性已得到保证。
- get_attr(target, args, kwargs)[source]#
执行
get_attr
节点。将从self.module
的Module
层级结构中检索属性值。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回
检索到的属性值
- 返回类型
任何
注意
此 API 的向后兼容性已得到保证。
- map_nodes_to_values(args, n)[source]#
递归地遍历
args
,并在当前执行环境中查找每个Node
的具体值。- 参数
args (Argument) – 要在其中查找具体值的数据结构
n (Node) –
args
所属的节点。这仅用于错误报告。
- 返回类型
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向后兼容性已得到保证。
- output(target, args, kwargs)[source]#
执行
output
节点。这实际上只是检索output
节点引用的值并返回它。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回
output 节点引用的返回值
- 返回类型
任何
注意
此 API 的向后兼容性已得到保证。
- placeholder(target, args, kwargs)[source]#
执行
placeholder
节点。请注意,这是有状态的:Interpreter
维护一个指向传递给run
的参数的内部迭代器,而此方法返回该迭代器的next()
。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回
检索到的参数值。
- 返回类型
任何
注意
此 API 的向后兼容性已得到保证。
- run(*args, initial_env=None, enable_io_processing=True)[source]#
通过解释来运行 module 并返回结果。
- 参数
*args – 运行模块的参数,按位置顺序排列
initial_env (Optional[Dict[Node, Any]]) – 执行的可选起始环境。这是一个将 Node 映射到任何值的字典。例如,这可用于预填充某些 Nodes 的结果,以便仅在解释器中进行部分求值。
enable_io_processing (bool) – 如果为 True,我们首先使用图的
process_inputs
和process_outputs
函数来处理输入和输出,然后再使用它们。
- 返回
执行模块返回的值
- 返回类型
任何
注意
此 API 的向后兼容性已得到保证。
- class torch.fx.Transformer(module)[source]#
Transformer
是一种特殊的解释器,它生成一个新的Module
。它公开了一个transform()
方法,该方法返回转换后的Module
。Transformer
不需要参数即可运行,不像Interpreter
。Transformer
完全以符号方式工作。示例
假设我们想要将所有的
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们对应的Tensor
方法)。我们可以像这样子类化Transformer
:class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数
module (GraphModule) – 要转换的
Module
。
注意
此 API 的向后兼容性已得到保证。
- get_attr(target, args, kwargs)[source]#
执行
get_attr
节点。在Transformer
中,此方法被覆盖,以将新的get_attr
节点插入到输出图中。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- placeholder(target, args, kwargs)[source]#
执行
placeholder
节点。在Transformer
中,此方法被覆盖,以将新的placeholder
插入到输出图中。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参阅 Node。
args (Tuple) – 此调用位置参数的元组
kwargs (Dict) – 此调用关键字参数的字典
- 返回类型
注意
此 API 的向后兼容性已得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[source]#
匹配
gm
的 Graph 中的所有可能的非重叠运算符及其数据依赖关系(pattern
),然后用另一个子图(replacement
)替换每个匹配的子图。- 参数
gm (GraphModule) – 包装 Graph 以进行操作的 GraphModule
pattern (Union[Callable, GraphModule]) – 要在
gm
中匹配以进行替换的子图replacement (Union[Callable, GraphModule]) – 用于替换
pattern
的子图
- 返回
一个
Match
对象列表,表示原始图中pattern
匹配到的位置。如果没有任何匹配项,则列表为空。Match
定义为class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型
List[Match]
示例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]) def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的代码将首先在
traced_module
的forward
方法中匹配pattern
。模式匹配是基于使用-定义关系进行的,而不是节点名称。例如,如果在pattern
中有p = torch.cat([a, b])
,您可以在原始forward
函数中匹配m = torch.cat([a, b])
,即使变量名不同(p
对m
)。pattern
中的return
语句是基于其值进行匹配的;它可能匹配也可能不匹配较大图中的return
语句。换句话说,模式不必扩展到整个较大图的末尾。当模式匹配成功时,它将从较大函数中移除,并由
replacement
替换。如果在较大函数中有多个pattern
的匹配项,每个非重叠的匹配项都将被替换。如果存在重叠匹配,则将替换重叠匹配项中找到的第一个匹配项。“第一个”在此定义为拓扑排序的节点使用-定义关系中的第一个节点。在大多数情况下,第一个节点是紧跟在self
之后的参数,而最后一个节点是函数返回的值。需要注意的一个重要事项是,
pattern
可调用对象的参数必须在可调用对象本身中使用,并且replacement
可调用对象的参数必须与模式匹配。第一个规则解释了为什么在上面的代码块中,forward
函数具有参数x, w1, w2
,但pattern
函数仅具有参数w1, w2
。pattern
不使用x
,因此不应将其指定为参数。作为第二个规则的示例,考虑用def pattern(x, y): return torch.neg(x) + torch.relu(y)
替换
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement
需要与pattern
相同数量的参数(x
和y
),即使参数y
在replacement
中未使用。调用
subgraph_rewriter.replace_pattern
后,生成的 Python 代码如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向后兼容性已得到保证。