在ATen IR上编写图转换#
创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日
Passes(过程)#
由于ATen IR位于FX Graph/GraphModule级别,因此可以轻松地将为FX Graphs编写的任何转换应用于ATen IR。如果您熟悉编写FX graph转换,那么这与此相同。
编写转换的最直接方法是遍历给定图并直接操作图中的节点。
例如,假设我们要将 torch.ops.aten.add.Tensor()
调用替换为 torch.ops.aten.mul.Tensor()
调用。
import torch
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.mul.Tensor
我们还可以通过FX实用程序函数删除和追加新节点,这些函数可以在 Graph 文档中找到。例如,如果我们想在 add
调用之后插入一个 torch.ops.aten.relu.default()
。
import torch
def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
# Specifies the insertion point. Any nodes added to the graph within
# this scope will be inserted after `node`
with gm.graph.inserting_after(node):
# Insert a new `call_function` node with op `torch.ops.aten.relu.default`
new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
# Replace all the places that use `node` to now use the `new_relu_node`
node.replace_all_uses_with(new_relu_node)
总的来说,转换可以大致分为几个轴
轴 A:1. 创建一对多映射(例如,分解)2. 创建多对一映射(例如,融合)
轴 B:1. 执行前向迭代(例如,形状传播)2. 执行后向迭代(例如,死代码消除)
轴 C:1. 依赖于局部节点信息(例如,out-variant转换)2. 依赖于全局图信息(例如,内存规划)
我们对这些用例的频率预测是:1. A.1、B.1、C.1 2. A.2 3. B.2、C.2
虽然我们可以通过直接操作图来完成所有图转换,但我们也提供了一些辅助实用程序,以方便使用第一级和第二级用例。
Transformer(转换器)#
对于第一级用例(创建一对多映射、执行前向迭代和查看局部节点信息),我们可以利用 Transformer 类来执行每个节点并重新创建图,但需要指定转换。
一对一 Pass(过程)#
以一对一映射为例,如果我们想用另一个op B替换op A,我们可以运行GraphModule,并且每次看到op A时,返回op B。
一个例子是
class ReplaceAddWithMul(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
super().call_function(target, args, kwargs, meta)
调用会创建一个 call_function
FX节点,并返回使用给定参数运行操作的结果。
一对多 Pass(过程)#
如果我们想进行一对多映射,例如用2个op B和C替换op A,我们就可以调用2次 super().call_function
来创建2个FX节点,一个包含op B,另一个包含op C,并返回运行op C的结果。
例如
class ReplaceAddWithMulSub(torch.fx.Transformer):
"""
Original:
def f(x, y):
return x + y
After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
x, y = args
mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})
transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
一对零 Pass(过程)#
如果我们想删除一个op,我们只需返回传递给函数的value。
class RemoveDetachPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_function(target, args, kwargs, meta)
assert len(args) == 1
return args[0]
transformed_graph_module = RemoveDetachPass(graph_module).transform()
利用局部信息#
利用局部节点信息的例子是,如果我们想将图中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule
,并且对于包含标量的每个参数,我们将其转换为张量。它可能看起来像
def args_map(target, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# Update the argument based on the function passed
def update(key, args, schema):
args[key] = fn(args[key], schema)
# Update each argument in the schema
for i, schema in enumerate(target._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
return tuple(args), kwargs
class ScalarToTensorPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
breakpoint()
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(target, try_coerce, args, kwargs)
return super().call_function(target, args, kwargs)
transformed_graph_module = ScalarToTensorPass(graph_module).transform()
Subgraph Rewriter(子图重写器)#
为了创建多对一映射,我们可以利用FX的 subgraph rewriter(子图重写器) 。给定一个 pattern
,它会创建一个匹配该模式的操作子图,然后将每个匹配的子图替换为 replacement
。
注意
This is an inplace operation.
pattern
和 replacement
输入必须是可调用函数或GraphModules,它们包含图中所使用的相同操作(ATen ops),以便子图重写器可以在图中找到正确的模式。模式/替换可调用函数的输入将在匹配时被视为通配符。
一个例子
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
子图重写器返回一个 ReplacedPatterns
列表。
@dataclass
class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.
Pass Manager(过程管理器)#
PassManager(过程管理器) 是一个用于在给定图模块上运行多个过程的类。在初始化 PassManager
实例时,我们传入要运行的过程列表并设置一些标志。要对图模块运行过程集合,我们可以直接将图模块传递给 PassManager
实例。
一个例子
from torch.fx.passes.infra.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要添加一套常见的检查,在每次过程运行后进行,我们可以调用函数 set_checks(check: Callable)
,它接受一个可调用函数作为输入。如果设置了 run_checks_after_each_pass
标志,则在图模块上运行每个过程后将调用 check
。
一个例子
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module) # raises ValueError after replace_div_with_mul pass
Partitioner(分区器)#
我们可以使用几种常见的基于FX图的分区器来分区图。
Subgraph Matcher(子图匹配器)#
要查找图中与特定模式匹配的子图,我们可以利用FX的 SubgraphMatcher
。
类属性
pattern (Graph)
: 目标匹配模式。图中的占位符节点在匹配时将被视为通配符。match_output (bool)
: 如果为 True,则模式图中的输出节点将被视为目标模式的一部分。如果为 False,则在匹配过程中将忽略输出节点。match_placeholder (bool)
: 如果为 True,则模式图中的占位符节点将被视为目标模式的一部分。如果为 False,则占位符节点将被用作通配符。remove_overlapping_matches (bool)
: 如果为 True,在重叠匹配的情况下,将只返回第一个匹配。ignore_literals (bool)
: 如果为 True,则不检查字面量是否相等,而是将它们视为通配符。
一个例子
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = torch.export(LargeModel(), inputs).graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = torch.export(PatternModel(), inputs).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match
函数返回一个 InternalMatch
列表。
@dataclass
class InternalMatch():
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# Nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# Nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
基于能力的 Partition(分区)#
为了找到支持特定不变性的节点的最大子图,我们可以利用 FX 的 CapabilityBasedPartitioner
。
类属性
graph_module (torch.fx.GraphModule)
:我们正在对其进行分区图模块。operator_support (OperatorSupportBase)
:用于确定图中节点是否支持分区的对象。allows_single_node_partition (bool)
:如果为 True,则允许形成单节点分区。non_compute_ops (Optional[Sequence[str]])
:一组被认为是“非计算”的算子(例如torch.ops.aten.view
和_operator.getitem
),因此分区器不会创建仅包含这些非计算算子的图allowed_single_node_partition_ops (Optional[Sequence[str]])
:一组允许包含在单节点分区中的算子。
分区器使用 OperatorSupportBase
类来确定图中的特定节点是否属于分区。这是通过覆盖 is_node_supported
函数来实现的。您可以通过使用 chain
(如果任何 OperatorSupportBase 返回 False,则返回 False)和 any_chain
(如果任何 OperatorSupportBase 返回 True,则返回 True)来链接多个 OperatorSupportBase
。
一个例子
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)