评价此页

在 ATen IR 上编写图变换#

创建日期:2025 年 6 月 11 日 | 最后更新日期:2025 年 6 月 11 日

Passes#

由于 ATen IR 位于 FX Graph/GraphModule 层,因此为 FX Graph 编写的任何变换都可以轻松应用于 ATen IR。如果您熟悉编写 FX 图变换,那么这与您已经了解的相同。

编写变换的最直接方法是遍历给定的图并直接操作图中的节点。

例如,假设我们要将 torch.ops.aten.add.Tensor() 调用替换为 torch.ops.aten.mul.Tensor() 调用

import torch

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor

我们还可以通过 FX 实用函数删除和添加新节点,这些函数可以在 Graph 文档中找到。例如,如果我们想在 add 调用之后插入一个 torch.ops.aten.relu.default()

import torch

def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use the `new_relu_node`
                node.replace_all_uses_with(new_relu_node)

总的来说,变换可以大致分为几个轴

轴 A:1. 创建一对多映射(例如,分解) 2. 创建多对一映射(例如,融合)

轴 B:1. 正向迭代(例如,形状传播) 2. 反向迭代(例如,死代码消除)

轴 C:1. 依赖于局部节点信息(例如,out-variant 转换) 2. 依赖于全局图信息(例如,内存规划)

我们对这些用例频率的预测是:1. A.1,B.1,C.1 2. A.2 3. B.2,C.2

虽然我们可以通过直接操作图来实现所有图变换,但我们也提供了一些辅助工具,以便在处理 1 级和 2 级用例时更加方便。

Transformer#

对于 1 级用例(创建一对多映射、进行正向迭代和查看局部节点信息),我们可以利用 Transformer 类来执行每个节点并重新创建一个图,但会应用指定的变换。

一对一 Pass#

对于一对一映射的示例,如果我们想用另一个 op B 替换 op A,我们可以运行 GraphModule,每次看到 op A 时,返回 op B。

一个例子是

class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

super().call_function(target, args, kwargs, meta) 的调用会创建一个 call_function FX 节点,并返回使用给定参数运行操作的结果。

一对多 Pass#

如果我们想进行一对多映射,例如用两个其他 op B 和 C 替换 op A,那么我们将调用两次 super().call_function 来创建两个 FX 节点,一个 op B,另一个 op C,并返回运行 op C 的结果。

例如

class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args

        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

一对零 Pass#

如果我们想移除一个 op,我们可以直接返回传递给函数的该值

class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

transformed_graph_module = RemoveDetachPass(graph_module).transform()

利用局部信息#

利用局部节点信息的例子是,如果我们想将图中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule,并且对于包含标量的每个参数,我们将其转换为张量。它可能看起来像

def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs

class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        breakpoint()
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)

transformed_graph_module = ScalarToTensorPass(graph_module).transform()

子图重写器#

要创建多对一映射,我们可以利用 FX 的 subgraph rewriter。给定一个 pattern,它会创建一个匹配该模式的操作子图,然后将每个匹配的子图替换为 replacement

注意

This is an inplace operation.

patternreplacement 输入必须是可调用函数或包含与图中使用(ATen ops)的相同操作的 GraphModules,这样子图重写器才能在图中找到正确的模式。模式/替换可调用对象的输入在匹配时将被视为通配符。

一个例子

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

子图重写器返回一个 ReplacedPatterns 列表

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意

The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.

Pass Manager#

PassManager 是一个用于在给定图模块上运行多个 pass 的类。在初始化 PassManager 实例时,我们会传入一个我们想要运行的 pass 列表,并设置几个标志。要在图模块上运行 pass 集合,我们可以直接将图模块传递给 PassManager 实例。

一个例子

from torch.fx.passes.infra.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

要添加一组常见的检查,在每个 pass 运行后进行,我们可以调用函数 set_checks(check: Callable),该函数接受一个可调用函数作为输入。如果设置了 run_checks_after_each_pass 标志,则在图模块上运行每个 pass 后将调用 check

一个例子

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

Partitioner#

我们可以使用几种常见的基于 FX 图的分区器来对图进行分区。

子图匹配器#

要查找图中与特定模式匹配的子图,我们可以利用 FX 的 SubgraphMatcher

类属性

  • pattern (Graph):目标匹配模式。图中的占位符节点在匹配时将被视为通配符。

  • match_output (bool):如果为 True,则模式图中的输出节点将被视为目标模式的一部分。如果为 False,则在匹配期间忽略输出节点。

  • match_placeholder (bool):如果为 True,则模式图中的占位符节点将被视为目标模式的一部分。如果为 False,则占位符节点将用作通配符。

  • remove_overlapping_matches (bool):如果为 True,在重叠匹配的情况下,将只返回第一个匹配项。

  • ignore_literals (bool):如果为 True,将不检查字面量是否相等,而是将它们视为通配符。

一个例子

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = torch.export(LargeModel(), inputs).graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = torch.export(PatternModel(), inputs).graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函数返回一个 InternalMatch 列表

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基于能力的 Partitioner#

要查找支持特定不变性的最大节点子图,我们可以利用 FX 的 CapabilityBasedPartitioner

类属性

  • graph_module (torch.fx.GraphModule):我们正在对其进行分区的图模块。

  • operator_support (OperatorSupportBase):用于确定图中的节点是否受分区支持的对象。

  • allows_single_node_partition (bool):如果为 True,则允许形成单节点分区。

  • non_compute_ops (Optional[Sequence[str]]):一组被视为“非计算”的操作(例如 torch.ops.aten.view_operator.getitem),以便分区器不会创建仅包含这些非计算操作的图。

  • allowed_single_node_partition_ops (Optional[Sequence[str]]):允许在单节点分区中使用的操作集。

OperatorSupportBase 类由分区器用来确定图中的特定节点是否属于该分区。这是通过覆盖 is_node_supported 函数来完成的。您可以通过使用 chain(如果任何 OperatorSupportBase 返回 False,则返回 False)和 any_chain(如果任何 OperatorSupportBase 返回 True,则返回 True)来链接多个 OperatorSupportBase

一个例子

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)