• 文档 >
  • 编写 Dynamo ATen 降级通道
快捷方式

编写 Dynamo ATen 降级 Passes

降级 Pass 的基础

ATen 降级 Passes 是 Python 函数,它们接收一个 ATen 操作符图作为输入,应用一些期望的修改,例如操作符合并/融合、操作符替换、子图重写、自定义操作符插入,或其他对 torch.fx.GraphModule 的操作,然后将修改后的图返回给调用者。这些降级 Passes 通常会就地修改图并返回相同的输入对象。

降级 Pass 要求

Torch-TRT 中的 ATen 降级 Pass 函数必须满足两个要求: - 该函数必须接收一个 torch.fx.GraphModule 和一系列 torch Tensors Sequence[torch.Tensor] 作为输入,并返回降级后的 torch.fx.GraphModule - 该函数必须使图保持有效且可调用的状态,包括执行任何必要的 linting 和重新编译

有关 FX 中 图操作 的信息,请参阅此链接。下面是一个降级 Pass 的示例,该 Pass 修复了图中输入也是输出的图,这是 TRT Engines 不允许的配置。

降级 Pass 示例

def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    """Repair scenarios where inputs are also outputs of the graph

    TRT does not allow such cases, so we insert a clone (identity) layer
    """
    modified_graph = False

    # Extract graph placeholder Tensors
    placeholders = [
        node
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.Tensor)
        )
    ]

    for placeholder in placeholders:
        # If any placeholder has any users which are direct graph outputs
        if len(placeholder.users) >= 1 and any(
            user.op == "output" for user in placeholder.users
        ):
            modified_graph = True

            # Get direct graph outputs which are direct uses of placeholders
            direct_outputs = [user for user in placeholder.users if user.op == "output"]

            # Insert clone node for placeholder to ensure
            # placeholder is not a direct output
            with gm.graph.inserting_after(placeholder):
                cloned_placeholder = gm.graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(placeholder,),
                )

            # Replace placeholder as output with cloned version
            for output in direct_outputs:
                output.replace_input_with(placeholder, cloned_placeholder)

    # If the graph was modified, clean up the graph and ensure it is up-to-date
    if modified_graph:
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

    return gm

注册降级 Passes

降级 Passes 目前在 py/torch_tensorrt/dynamo/lowering/passes/__init__.py 中注册,使用 torch.fx.passes.pass_manager.PassManager 工具按期望的顺序组装 Pass 列表。直接添加到该列表中的新 Passes 将应用于 Torch-TensorRT torch.compile 后端的图。目前,我们提供了一个方便的 ATen 降级 Pass 注册装饰器,可以按原样调用,或使用可选的 index 关键字参数来控制降级 Pass 在 Pass 列表中的插入位置。

例如,要在默认位置(列表末尾)插入 Pass,可以使用以下代码

@_aten_lowering_pass
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    ...

或者,要在 Pass 列表的自定义索引(例如列表开头)处插入 Pass,可以使用以下代码

@_aten_lowering_pass(index=0)
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    ...

torch_tensorrt.dynamo.lowering.passes 中还提供了用于显示当前可用的降级 Pass 列表、将这些 Pass 应用于任意 torch.fx.GraphModule 以及删除特定索引处的降级 Pass 的实用工具。

# Print all lowering passes in the list
print(dump_lowering_passes())

# Apply lowering passes to a GraphModule
apply_lowering_passes(graph_module, sample_inputs)

# Remove the lowering pass at index 1
_remove_lowering_pass(index=1)

注意: 上述 API 可能会发生变化,因为降级 Pass 系统仍在发展中。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源