评价此页

torch.export IR 规范#

创建于: 2023年10月05日 | 最后更新于: 2025年06月13日

Export IR 是一个用于编译器的中间表示 (IR),它与 MLIR 和 TorchScript 类似。它专门用于表达 PyTorch 程序的语义。Export IR 主要以简化的操作列表来表示计算,对控制流等动态性支持有限。

要创建 Export IR 图,可以使用前端,通过一个跟踪特殊化机制来可靠地捕获 PyTorch 程序。生成的 Export IR 随后可以由后端进行优化和执行。目前可以通过 torch.export.export() 来实现这一点。

本文档将涵盖的关键概念包括:

  • ExportedProgram:包含 Export IR 程序的的数据结构

  • Graph:由节点列表组成。

  • Nodes:代表操作、控制流以及存储在该节点上的元数据。

  • 值由节点生成和消耗。

  • 类型与值和节点相关联。

  • 还定义了值的尺寸和内存布局。

假设#

本文档假设读者已充分熟悉 PyTorch,特别是 torch.fx 及其相关工具。因此,它将不再描述 torch.fx 文档和论文中已包含的内容。

什么是 Export IR#

Export IR 是 PyTorch 程序的基于图的中间表示 IR。Export IR 实现于 torch.fx.Graph 之上。换句话说,**所有 Export IR 图也是有效的 FX 图**,如果使用标准的 FX 语义进行解释,Export IR 可以被可靠地解释。一个隐含的结论是,通过标准的 FX 代码生成,导出的图可以被转换为有效的 Python 程序。

本文档将主要关注 Export IR 与 FX 在严格性方面的差异,并跳过它们共享的相似部分。

ExportedProgram#

顶级的 Export IR 构造是 torch.export.ExportedProgram 类。它将 PyTorch 模型(通常是 torch.nn.Module)的计算图与该模型消耗的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些值得注意的属性包括:

  • graph_moduletorch.fx.GraphModule):包含 PyTorch 模型展平计算图的数据结构。可以通过 ExportedProgram.graph 直接访问该图。

  • graph_signaturetorch.export.ExportGraphSignature):图签名,它指定了图中使用的参数和缓冲区名称以及被修改的参数和缓冲区。它不是将参数和缓冲区存储为图的属性,而是将它们提升为图的输入。graph_signature 用于跟踪这些参数和缓冲区上的附加信息。

  • state_dictDict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含参数和缓冲区的的数据结构。

  • range_constraintsDict[sympy.Symbol, RangeConstraint]):对于具有数据依赖行为导出的程序,每个节点上的元数据将包含符号形状(看起来像 s0i0)。此属性将符号形状映射到它们的下限/上限范围。

Graph#

Export IR Graph 是以 DAG(有向无环图)形式表示的 PyTorch 程序。图中的每个节点代表一个特定的计算或操作,图的边由节点之间的引用组成。

我们可以将 Graph 看作具有以下模式:

class Graph:
  nodes: List[Node]

在实践中,Export IR 的图是通过 torch.fx.Graph Python 类实现的。

Export IR 图包含以下节点(节点将在下一节更详细地描述):

  • 0 个或多个 placeholder 类型的节点

  • 0 个或多个 call_function 类型的节点

  • 恰好 1 个 output 类型的节点

推论: 最小的有效 Graph 将是单个节点。即节点永远不会为空。

定义: Graph 的 placeholder 节点集代表 GraphModule 的**输入**。Graph 的 output 节点代表 GraphModule 的**输出**。

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是 Graph 的文本表示,每一行代表一个节点。

Node#

Node 代表一个特定的计算或操作,并使用 torch.fx.Node 类在 Python 中表示。节点之间的边通过 Node 类的 args 属性直接表示为对其他节点的引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如操作调用、占位符(也称为输入)、条件和循环。

Node 具有以下模式:

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文本格式

如上例所示,请注意,每行都遵循以下格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以紧凑的方式捕获了 Node 类中的所有内容,meta 除外。

具体来说:

  • <name> 是节点在 node.name 中出现的名称。

  • <op_name>node.op 字段,必须是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 是节点作为 node.target 的目标。此字段的含义取决于 op_name

  • args1, … args 4…node.args 元组中列出的内容。如果列表中的值为 torch.fx.Node,则会以前导的 % 特别指示。

例如,对 add 运算符的调用将显示为

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是另外两个名为 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor 代表实际存储在 target 字段中的可调用对象,而不仅仅是其字符串名称。

这种文本格式的最后一行是

return [add]

这是一个 op_name = output 的节点,表示我们正在返回该元素。

call_function#

一个 call_function 节点表示对运算符的调用。

定义

  • 函数式: 我们说一个可调用对象是“函数式”的,如果它满足以下所有要求:

    • 非变异:运算符不会改变其输入的 D值(对于张量,这包括元数据和数据)。

    • 无副作用:运算符不会改变从外部可见的状态,例如更改模块参数的值。

  • 运算符: 是具有预定义模式的函数式可调用对象。此类运算符的示例包括函数式 ATen 运算符。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与普通 FX call_function 的区别

  1. 在 FX 图中,call_function 可以引用任何可调用对象,而在 Export IR 中,我们将其限制为仅选择一部分 ATen 运算符、自定义运算符和控制流运算符。

  2. 在 Export IR 中,常量参数将嵌入到图中。

  3. 在 FX 图中,get_attr 节点可以表示读取图模块中存储的任何属性。然而,在 Export IR 中,这被限制为仅读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。

元数据#

Node.meta 是附加到每个 FX 节点的字典。但是,FX 规范并未指定哪些元数据可能存在或将会存在。Export IR 提供了更强的约定,特别是所有 call_function 节点都将保证具有且仅具有以下元数据字段:

  • node.meta["stack_trace"] 是一个字符串,包含引用原始 Python 源代码的 Python 堆栈跟踪。堆栈跟踪示例看起来像

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了运行操作的输出。它可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None 类型。

  • node.meta["nn_module_stack"] 描述了节点来自的 torch.nn.Module 的“堆栈跟踪”,如果它来自 torch.nn.Module 调用。例如,如果一个包含 addmm 运算符的节点是从 torch.nn.Linear 模块内部的 torch.nn.Sequential 模块调用的,则 nn_module_stack 将如下所示:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含在分解之前调用该节点的 torch 函数或叶子 torch.nn.Module 类。例如,一个包含来自 torch.nn.Linear 模块调用的 addmm 运算符的节点将在其 source_fn 中包含 torch.nn.Linear,而一个包含来自 torch.nn.functional.Linear 模块调用的 addmm 运算符的节点将在其 source_fn 中包含 torch.nn.functional.Linear

placeholder#

Placeholder 代表图的输入。其语义与 FX 中的完全相同。Placeholder 节点必须是图中节点列表的前 N 个节点。N 可以为零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

target 字段是输入名称的字符串。

args(如果非空)应大小为 1,表示此输入的默认值。

元数据

Placeholder 节点也具有 meta[‘val’],就像 call_function 节点一样。在这种情况下,val 字段表示图在编译时预期接收的该输入的形状/dtype。

output#

输出调用代表函数中的 return 语句;因此,它终止了当前图。只有一个输出节点,并且它将始终是图的最后一个节点。

在 FX 中的表示

output[](args = (%something, …))

这与 torch.fx 中的语义完全相同。args 表示要返回的节点。

元数据

输出节点的元数据与 call_function 节点相同。

get_attr#

get_attr 节点表示从封装的 torch.fx.GraphModule 读取子模块。与 torch.fx.symbolic_trace() 的普通 FX 图不同,在普通 FX 图中 get_attr 节点用于从顶层 torch.fx.GraphModule 读取参数和缓冲区等属性,在 Export IR 中,参数和缓冲区作为输入传递给图模块,并存储在顶层 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取包含 sin 运算符的子模块 true_graph_0

参考文献#

SymInt#

SymInt 是一个对象,它可以是字面整数,也可以是表示整数的符号(在 Python 中用 sympy.Symbol 类表示)。当 SymInt 是符号时,它描述了一个在编译时对图未知的整数类型变量,也就是说,它的值仅在运行时才知道。

FakeTensor#

FakeTensor 是一个包含张量元数据quoi的对象。它可以被视为具有以下元数据。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 字段是整数或 SymInts 的列表。如果存在 SymInts,则表示此张量具有动态形状。如果存在整数,则假定张量将具有该确切的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段表示该节点输出的 dtype。Edge IR 中没有隐式类型提升。FakeTensor 中没有 strides。

换句话说

  • 如果 node.target 中的运算符返回一个 Tensor,则 node.meta['val'] 是一个描述该张量的 FakeTensor。

  • 如果 node.target 中的运算符返回一个 n 元组的 Tensor,则 node.meta['val'] 是一个描述每个张量的 n 元组的 FakeTensors。

  • 如果 node.target 中的运算符返回一个在编译时已知的 int/float/scalar,则 node.meta['val'] 为 None。

  • 如果 node.target 中的运算符返回一个在编译时未知的 int/float/scalar,则 node.meta['val'] 的类型为 SymInt。

例如

  • aten::add 返回一个 Tensor;因此,其规范将是描述该运算符返回的张量的 dtype 和大小的 FakeTensor。

  • aten::sym_size 返回一个整数;因此,其 val 将是 SymInt,因为其值仅在运行时可用。

  • max_pool2d_with_indexes 返回一个(Tensor,Tensor)元组;因此,规范也将是一个 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值的第一个元素,依此类推。

Python 代码

def add_one(x):
  return torch.ops.aten(x, 1)

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 类型#

我们将一种类型定义为“Pytree-able”,如果它是一个叶子类型或一个包含其他 Pytree-able 类型的容器类型。

注意

Pytree 的概念与 JAX 的文档 此处 记录的概念相同。

以下类型定义为叶子类型

类型

定义

张量

torch.Tensor

Scalar

Python 中的任何数值类型,包括整数类型、浮点类型和零维张量。

int

Python int(在 C++ 中绑定为 int64_t)

浮点数

Python float(在 C++ 中绑定为 double)

布尔值

Python bool

str

Python string

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

设备

torch.device

以下类型定义为容器类型

类型

定义

Tuple

Python tuple

List

Python list

Dict

键为 Scalar 的 Python dict

NamedTuple

Python namedtuple

Dataclass

必须通过 register_dataclass 注册

Custom class

通过 _register_pytree_node 定义的任何自定义类