评价此页

torch.export IR 规范#

创建于:2023年10月05日 | 最后更新于:2025年07月16日

Export IR(导出中间表示)是用于编译器的中间表示(IR),它与MLIR和TorchScript有相似之处。它专门用于表达PyTorch程序的语义。Export IR主要以精简的操作列表形式表示计算,并对动态性(如控制流)提供有限的支持。

要创建Export IR图,可以使用前端,通过一个健全的、专用于追踪的机制来捕获PyTorch程序。生成的Export IR随后可以由后端进行优化和执行。今天,可以通过torch.export.export()来完成此操作。

本文档将涵盖的关键概念包括:

  • ExportedProgram:包含Export IR程序的数据结构

  • Graph:由节点列表组成。

  • Nodes:表示在此节点上存储的操作、控制流和元数据。

  • 值由节点产生和消费。

  • 类型与值和节点相关联。

  • 还定义了值的尺寸和内存布局。

假设#

本文档假设读者对PyTorch有充分的了解,特别是对torch.fx及其相关工具链有深入的了解。因此,它将不再描述torch.fx文档和论文中已有的内容。

什么是Export IR#

Export IR是PyTorch程序的基于图的中间表示IR。Export IR建立在torch.fx.Graph之上实现的。换句话说,**所有Export IR图也是有效的FX图**,如果使用标准的FX语义来解释,Export IR可以被健全地解释。一个隐含的含义是,通过标准的FX代码生成,导出的图可以转换为有效的Python程序。

本文档将主要关注Export IR在严格性方面与FX不同的地方,而跳过与FX相似的部分。

ExportedProgram#

顶级Export IR结构是torch.export.ExportedProgram类。它将PyTorch模型的计算图(通常是torch.nn.Module)与其消耗的参数或权重捆绑在一起。

torch.export.ExportedProgram类的一些值得注意的属性包括:

  • graph_moduletorch.fx.GraphModule):包含PyTorch模型扁平化计算图的数据结构。可以通过ExportedProgram.graph直接访问该图。

  • graph_signaturetorch.export.ExportGraphSignature):图的签名,它指定了图中使用的和修改的参数和缓冲区名称。参数和缓冲区没有存储为图的属性,而是被提升为图的输入。graph_signature用于跟踪这些参数和缓冲区的额外信息。

  • state_dictDict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含参数和缓冲区的数据结构。

  • range_constraintsDict[sympy.Symbol, RangeConstraint]):对于具有数据依赖行为导出的程序,每个节点的元数据将包含符号形状(看起来像s0, i0)。此属性将符号形状映射到它们的下限/上限范围。

#

Export IR图是以DAG(有向无环图)形式表示的PyTorch程序。图中的每个节点代表一个特定的计算或操作,图的边由节点之间的引用组成。

我们可以将图看作具有以下模式:

class Graph:
  nodes: List[Node]

实际上,Export IR的图是通过torch.fx.Graph Python类实现的。

Export IR图包含以下节点(节点将在下一节中更详细地描述):

  • 0个或多个placeholder类型的节点

  • 0个或多个call_function类型的节点

  • 正好1个output类型的节点

推论: 最小的有效图将是单个节点。即节点列表从不为空。

定义: 图的placeholder节点集合代表GraphModule的**输入**。图的output节点代表GraphModule的**输出**。

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是图的文本表示,每一行代表一个节点。

节点#

节点代表一个特定的计算或操作,在Python中使用torch.fx.Node类表示。节点之间的边通过Node类的args属性直接引用其他节点来表示。使用相同的FX机制,我们可以表示计算图通常需要的以下操作,例如操作调用、占位符(又名输入)、条件和循环。

Node具有以下模式:

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文本格式

如上例所示,请注意,每一行都遵循此格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以紧凑的方式捕获了Node类中的所有内容,除了meta字段。

具体来说:

  • 是节点在node.name中出现的名称。

  • <op_name>node.op字段,必须是以下之一:<call_function><placeholder><get_attr><output>

  • 是节点的目标,即node.target。此字段的含义取决于op_name

  • args1, … args 4…node.args元组中列出的内容。如果列表中的值是torch.fx.Node,则会通过前导的%进行特别指示。

例如,对add操作的调用将显示为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中%x%y是另外两个名为x和y的节点。值得注意的是,字符串torch.op.aten.add.Tensor代表存储在target字段中的可调用对象,而不仅仅是其字符串名称。

此文本格式的最后一行是:

return [add]

这是一个op_name为output的节点,表示我们正在返回这个单一的元素。

call_function#

一个call_function节点表示对一个操作的调用。

定义

  • 函数式: 我们说一个可调用对象是“函数式”的,如果它满足以下所有要求:

    • 非修改性:该操作不会修改其输入的(对于张量,这包括元数据和数据)。

    • 无副作用:该操作不会修改从外部可见的状态,例如更改模块参数的值。

  • 操作: 是一个具有预定义模式的函数式可调用对象。这类操作的例子包括函数式ATen操作。

在FX中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与普通FX的call_function的区别

  1. 在FX图中,call_function可以引用任何可调用对象,而在Export IR中,我们将其限制为仅限于选定的ATen操作、自定义操作和控制流操作。

  2. 在Export IR中,常量参数将被嵌入到图中。

  3. 在FX图中,get_attr节点可以表示读取图模块中存储的任何属性。然而,在Export IR中,这仅限于读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。

元数据#

Node.meta是一个附加到每个FX节点上的字典。然而,FX规范并未指定可以或将有哪些元数据。Export IR提供了更强的契约,特别是所有call_function节点将保证具有并且仅具有以下元数据字段:

  • node.meta["stack_trace"]是一个包含Python堆栈跟踪的字符串,它引用了原始的Python源代码。一个堆栈跟踪的例子如下:

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"]描述了运行操作的输出。它可以是<symint><FakeTensor>List[Union[FakeTensor, SymInt]]None类型的。

  • node.meta["nn_module_stack"]描述了该节点来自的torch.nn.Module的“堆栈跟踪”,如果它来自torch.nn.Module调用。例如,如果一个包含addmm操作的节点是从torch.nn.Linear模块中调用,而该模块又在torch.nn.Sequential模块内,那么nn_module_stack将看起来像这样:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"]包含该节点在分解之前被调用的torch函数或叶torch.nn.Module类。例如,一个包含addmm操作且来自torch.nn.Linear模块调用的节点,将在其source_fn中包含torch.nn.Linear,而一个包含addmm操作且来自torch.nn.functional.Linear模块调用的节点,将在其source_fn中包含torch.nn.functional.Linear

placeholder#

Placeholder代表图的输入。其语义与FX中的完全相同。Placeholder节点必须是图中节点列表的最初N个节点。N可以为零。

在FX中的表示

%name = placeholder[target = name](args = ())

target字段是一个字符串,表示输入的名称。

args,如果不为空,则应为大小为1,表示此输入的默认值。

元数据

Placeholder节点也具有meta[‘val’],类似于call_function节点。在这种情况下,val字段表示图期望接收的此输入参数的形状/数据类型。

output#

输出调用代表函数中的return语句;因此,它会终止当前图。只有一个输出节点,并且它将始终是图的最后一个节点。

在FX中的表示

output[](args = (%something, …))

这与torch.fx中的语义完全相同。args代表要返回的节点。

元数据

输出节点具有与call_function节点相同的元数据。

get_attr#

get_attr节点表示从包含的torch.fx.GraphModule中读取子模块。与torch.fx.symbolic_trace()生成的普通FX图不同,在普通FX图中get_attr节点用于从顶层torch.fx.GraphModule读取参数和缓冲区等属性,而在Export IR中,参数和缓冲区作为输入传递给图模块,并存储在顶层的torch.export.ExportedProgram中。

在FX中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

Graph

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0],读取包含sin操作的子模块true_graph_0

参考文献#

SymInt#

SymInt 是一个对象,它可以是字面整数,也可以是代表整数的符号(在Python中由sympy.Symbol类表示)。当SymInt是符号时,它描述了一个在编译时图中未知类型的整数变量,也就是说,它的值只在运行时才知道。

FakeTensor#

FakeTensor 是一个包含张量元数据的对象。它可以被看作具有以下元数据:

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor的size字段是整数或SymInts的列表。如果存在SymInts,则表示该张量具有动态形状。如果存在整数,则假定该张量具有精确的静态形状。TensorMeta的秩永远不会是动态的。dtype字段表示该节点输出的数据类型。Edge IR中没有隐式类型提升。FakeTensor中没有步幅(strides)。

换句话说:

  • 如果node.target中的操作返回一个Tensor,那么node.meta['val']是一个描述该张量的FakeTensor。

  • 如果node.target中的操作返回一个Tensor的n元组,那么node.meta['val']是一个描述每个张量的FakeTensor的n元组。

  • 如果node.target中的操作返回一个在编译时已知的int/float/scalar,那么node.meta['val']为None。

  • 如果node.target中的操作返回一个在编译时未知的int/float/scalar,那么node.meta['val']是SymInt类型。

例如

  • aten::add返回一个Tensor;因此,其规范将是一个FakeTensor,包含该操作返回的张量的数据类型和尺寸。

  • aten::sym_size返回一个整数;因此,其val将是一个SymInt,因为它的值只在运行时可用。

  • max_pool2d_with_indexes返回一个(Tensor,Tensor)元组;因此,其规范也将是一个FakeTensor对象的2元组,第一个TensorMeta描述返回值的第一项,依此类推。

Python代码

def add_one(x):
  return torch.ops.aten(x, 1)

Graph

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 类型#

我们将一个类型定义为“Pytree-able”,如果它是叶子类型或包含其他Pytree-able类型的容器类型。

注意

Pytree的概念与JAX文档此处中记录的概念相同。

以下类型被定义为**叶子类型**:

类型

定义

张量

torch.Tensor

Scalar

Python中的任何数值类型,包括整数类型、浮点数类型和零维张量。

int

Python int(在C++中绑定为int64_t)

浮点数

Python float(在C++中绑定为double)

布尔值

Python bool

str

Python string

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

设备

torch.device

以下类型被定义为**容器类型**:

类型

定义

Tuple

Python tuple

List

Python list

Dict

具有标量键的Python dict

NamedTuple

Python namedtuple

Dataclass

必须通过register_dataclass注册

自定义类

使用_register_pytree_node定义的任何自定义类