torch.export#

Created On: Jun 12, 2025 | Last Updated On: Aug 11, 2025

Overview#

torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.

import torch
from torch.export import export, ExportedProgram

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
             # File: /tmp/ipykernel_647/2550508656.py:6 in forward, code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x);  x = None
            
             # File: /tmp/ipykernel_647/2550508656.py:7 in forward, code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y);  y = None
            
             # File: /tmp/ipykernel_647/2550508656.py:8 in forward, code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)
            
Graph signature: 
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    
Range constraints: {}

torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

  • Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions of the original program.

  • Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.

  • Graph properties: The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.

  • Metadata: The graph contains metadata captured during tracing, such as a stack trace from the user's code.

Under the hood, torch.export leverages the following latest technologies:

  • TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph capturing experience, with far fewer rewrites needed in order to fully trace PyTorch code.

  • AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.

  • Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.

Existing frameworks#

torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

  • JIT vs. AOT: torch.compile() is a JIT compiler, whereas torch.export is an AOT compiler that is not intended to be used to produce compiled artifacts outside of deployment.

  • Partial vs. Full Graph Capture: When torch.compile() runs into an untraceable part of a model, it will "graph break" and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.

  • Usability Tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is much more flexible. torch.export, on the other hand, requires users to provide more information or rewrite their code to make it traceable.

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of Tensor metadata, so that conditionals on things like Tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, unless explicit control flow operators are used, but it supports more Python language features due to its comprehensive coverage of Python bytecodes. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.

Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computation on sizes and records all of the side conditions necessary to ensure that a particular trace is valid for other inputs.

Exporting a PyTorch Model#

The main entrypoint is torch.export.export(), which takes a torch.nn.Module and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:

import torch
from torch.export import export, ExportedProgram

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  x = p_conv_weight = p_conv_bias = None
            
             # File: /tmp/ipykernel_647/2848084713.py:16 in forward, code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant);  conv2d = constant = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_);  add_ = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]);  relu = None
            return (max_pool2d,)
            
Graph signature: 
    # inputs
    p_conv_weight: PARAMETER target='conv.weight'
    p_conv_bias: PARAMETER target='conv.bias'
    x: USER_INPUT
    constant: USER_INPUT
    
    # outputs
    max_pool2d: USER_OUTPUT
    
Range constraints: {}

tensor([[[[1.3647, 2.2724, 1.9313,  ..., 2.2872, 2.0532, 1.7862],
          [2.3376, 1.7866, 2.0821,  ..., 1.8447, 2.5090, 1.4548],
          [2.1719, 2.0520, 1.6153,  ..., 2.1426, 1.4364, 1.7749],
          ...,
          [2.1165, 2.0080, 1.6747,  ..., 1.8013, 2.3714, 2.1069],
          [1.9574, 1.9642, 1.6017,  ..., 2.2444, 1.7903, 2.1224],
          [1.6735, 1.4687, 1.4179,  ..., 2.1677, 1.4616, 1.9001]],

         [[1.7637, 2.2845, 1.6327,  ..., 1.8567, 1.6745, 1.3376],
          [1.8699, 1.1547, 2.5045,  ..., 1.1387, 2.0210, 1.3825],
          [1.7050, 1.3393, 1.9955,  ..., 1.4296, 1.8792, 1.7073],
          ...,
          [1.5292, 1.9183, 1.8844,  ..., 2.0815, 1.8089, 2.5264],
          [2.0338, 2.3296, 2.1650,  ..., 2.0727, 2.0166, 1.7309],
          [1.8381, 1.6556, 1.9402,  ..., 1.6529, 1.8134, 1.6075]],

         [[1.6673, 2.1074, 1.5976,  ..., 1.7421, 1.7998, 1.6087],
          [2.0269, 1.3379, 1.6679,  ..., 1.6671, 1.6000, 1.9894],
          [2.0480, 1.5340, 1.4017,  ..., 1.7944, 0.9860, 2.3785],
          ...,
          [1.5266, 1.3949, 1.2980,  ..., 1.3569, 1.9492, 1.8062],
          [1.8315, 1.2293, 1.1087,  ..., 1.5446, 1.6492, 1.6620],
          [1.4799, 1.3720, 1.5748,  ..., 1.8854, 1.2940, 1.7422]],

         ...,

         [[1.5734, 2.2576, 1.6242,  ..., 2.2690, 1.5416, 1.6914],
          [1.6577, 1.1605, 1.3565,  ..., 0.8677, 1.1838, 1.7662],
          [2.0769, 1.6546, 1.6169,  ..., 2.1301, 1.3892, 1.7564],
          ...,
          [1.4403, 2.0147, 2.0693,  ..., 1.8359, 1.3394, 1.6654],
          [1.6513, 2.2535, 2.0069,  ..., 1.0825, 1.6039, 1.3635],
          [1.3937, 1.0050, 2.8575,  ..., 1.9308, 1.4201, 1.4665]],

         [[1.8365, 1.3995, 1.8011,  ..., 1.7474, 2.0621, 1.7876],
          [1.7984, 2.0523, 1.4683,  ..., 1.7031, 2.4383, 1.2690],
          [1.6586, 1.4678, 1.7569,  ..., 1.4851, 1.5530, 2.1754],
          ...,
          [1.3299, 1.2796, 1.9345,  ..., 1.8214, 2.7972, 2.0472],
          [1.3732, 1.3926, 1.5657,  ..., 1.4157, 2.2771, 1.9893],
          [1.4669, 1.7627, 1.6258,  ..., 1.5350, 2.0609, 2.1192]],

         [[1.8094, 2.4813, 1.9160,  ..., 2.0701, 1.2966, 1.0297],
          [2.2580, 1.9149, 1.8033,  ..., 1.4026, 1.9687, 1.5859],
          [1.3034, 1.1883, 1.5391,  ..., 1.5249, 1.5463, 2.0138],
          ...,
          [2.2308, 1.7266, 1.8078,  ..., 1.7636, 2.2842, 1.5676],
          [1.7690, 1.3066, 1.8903,  ..., 1.8704, 1.5074, 1.5758],
          [2.3337, 1.5704, 1.8364,  ..., 1.4995, 1.8263, 1.8836]]]],
       grad_fn=<MaxPool2DWithIndicesBackward0>)

Inspecting the ExportedProgram, we can note the following:

  • The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.

  • The graph contains only torch.ops.aten operators, found here, and custom operators.

  • The parameters (weight and bias to conv) are lifted as inputs to the graph, so there are no `get_attr` nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().

  • The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.

  • The resulting shape and dtype of tensors produced by each node in the graph are recorded. For example, the `conv2d` node will result in a tensor of dtype `torch.float32` and shape (1, 16, 256, 256). This metadata can also be inspected programmatically, as sketched below.
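
For instance, here is a minimal sketch (not part of the original example) of reading that recorded metadata off the exported_program above; it assumes that each call_function node stores its FakeTensor result under node.meta["val"], as in current releases:

# Walk the traced graph and print the recorded shape/dtype of each ATen call
for node in exported_program.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")  # FakeTensor capturing shape and dtype
        if val is not None:
            print(node.target, tuple(val.shape), val.dtype)

# The lifted parameters and user inputs/outputs are described by the graph signature
print(exported_program.graph_signature)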

Expressing Dynamism#

By default, torch.export traces the program assuming all input shapes are static, and specializes the exported program to those dimensions. One consequence of this is that at runtime the program will not work on inputs with different shapes, even if they are valid in eager mode.

An example:

import torch
import traceback as tb

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

ep = torch.export.export(M(), example_args)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_647/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {}
Traceback (most recent call last):
  File "/tmp/ipykernel_647/1522925308.py", line 28, in <module>
    ep.module()(*example_args2)  # fails
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.25", line 11, in forward
    _guards_fn = self._guards_fn(x1, x2);  _guards_fn = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
    return func(*args, **kwargs)
  File "<string>", line 3, in _
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/__init__.py", line 2185, in _assert
    assert condition, message
AssertionError: Guard failed: x1.size()[0] == 32

However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be created using the torch.export.Dim() API and passed into torch.export.export() through the dynamic_shapes argument:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2)  # success
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_647/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.8873, 1.2537,  ..., 1.8441, 1.5857, 1.2611],
         [1.0000, 1.0000, 1.0598,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 2.3600,  ..., 1.8214, 1.0710, 1.0000],
         ...,
         [1.0000, 1.0516, 1.1907,  ..., 1.4364, 1.0000, 1.0000],
         [1.0000, 1.1808, 1.0000,  ..., 1.0000, 1.0000, 1.3670],
         [1.5627, 2.3273, 1.0000,  ..., 1.0000, 1.0000, 1.6566]],
        grad_fn=<AddBackward0>),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.1734, 0.0999, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4386, 0.7500],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.1618, 0.0000, 0.0000,  ..., 0.0000, 0.1554, 0.0000],
         [0.0000, 0.4619, 0.0904,  ..., 0.0000, 0.0792, 0.5539],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
        grad_fn=<ReluBackward0>))

Some additional things to note:

  • Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and `x2`, they now have a symbolic shape of (s24, 64) and (s24, 128), instead of the (32, 64)- and (32, 128)-shaped tensors we passed in as example inputs. `s24` is a symbol representing that this dimension can take a range of values.

  • exported_program.range_constraints describes the range of each symbol appearing in the graph. In this case, we see that `s24` has the range [0, int_oo]. For technical reasons that are difficult to explain here, the symbols are assumed to be neither 0 nor 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.

In this example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism, which we cover in the following sections.

Named Dims#

For every dimension specified with Dim("name"), we allocate a symbolic shape. Specifying a Dim with the same name results in the same symbol being generated. This allows users to specify which symbols are allocated for each input's dimensions.

batch = Dim("batch")
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

For each Dim, we can specify minimum and maximum values. We also allow specifying relations between Dims in univariate linear expressions: A * dim + B. This allows users to specify more complex constraints, such as integer divisibility, for dynamic dimensions. These features let users place explicit restrictions on the dynamic behavior of the produced ExportedProgram.

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

However, ConstraintViolationErrors will be raised if guards emitted while tracing conflict with the given relations or static/dynamic specifications. For example, the above specification asserts the following:

  • x.shape[0] has the range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has the range [0, 512], and is unrelated to any other dimension.

If any of these assertions are found to be incorrect while tracing (for example, `x.shape[0]` is static, or `y.shape[1]` has a smaller range, or `y.shape[0] != 2 * x.shape[0]`), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification.
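
As an illustrative sketch (the module below is made up for this purpose, not part of the original docs), the specification above can be exercised end to end; the example inputs satisfy y.shape[0] == 2 * x.shape[0], so tracing succeeds and the declared bounds show up in the range constraints:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x, y):
        # Concatenating x with itself produces a first dimension of 2 * x.shape[0]
        return torch.cat([x, x], dim=0) + y.sum()

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

# The example inputs must be consistent with the spec: 16 == 2 * 8
ep = export(M(), (torch.randn(8, 16), torch.randn(16, 32)), dynamic_shapes=dynamic_shapes)
print(ep.range_constraints)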

Dim Hints#

Instead of explicitly specifying dynamism with Dim("name"), we can let torch.export infer the ranges and relationships of the dynamic values using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don't know exactly how dynamic your dynamic values are.

dynamic_shapes = {
    "x": (Dim.DYNAMIC, None),
    "y": (Dim.DYNAMIC, Dim.DYNAMIC),
}

We can also specify min/max values for Dim.DYNAMIC, which serve as hints to export. But if while tracing export finds the range to be different, it will automatically update the range without raising an error. We also cannot specify relations between dynamic values; instead, these are inferred by export and exposed to users through inspecting the assertions within the graph. With this method of specifying dynamism, ConstraintViolationErrors are raised only if a value marked as dynamic is inferred to be static.

An even more convenient way to specify dynamism is to use Dim.AUTO, which behaves like Dim.DYNAMIC except that it does not raise an error if the dimension is inferred to be static. This is useful when you have no idea how dynamic your values are and want to export the program with a "best effort" dynamic approach.
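
For example, a minimal sketch (toy module, not from the original docs) that marks every dimension with Dim.AUTO and lets export figure out what is actually dynamic:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x.reshape(-1) * 2

# Every dimension is marked "automatic"; any dimension inferred to be static is
# simply exported as static instead of raising a ConstraintViolationError.
dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO)}
ep = export(M(), (torch.randn(8, 4),), dynamic_shapes=dynamic_shapes)
print(ep.range_constraints)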

ShapesCollection#

When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

we would need to specify the dynamism of `tensor_x`, `tensor_y`, and `tensor_z` along the lines of:

# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

However, this is particularly complicated, as we need to specify the `dynamic_shapes` specification in the same nested input structure as the input arguments. Instead, an easier way to specify dynamic shapes is with the helper utility torch.export.ShapesCollection, where instead of specifying the dynamism of every single input, we can just assign directly which input dimensions are dynamic.

import torch

class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z

tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}

print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
             # File: /tmp/ipykernel_647/1070110726.py:5 in forward, code: x = inp["x"] * 1
            mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1);  inp_x = None
            
             # File: /tmp/ipykernel_647/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
            mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2);  inp_others_0 = None
            
             # File: /tmp/ipykernel_647/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
            mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3);  inp_others_1 = None
            return (mul, mul_1, mul_2)
            
Graph signature: 
    # inputs
    inp_x: USER_INPUT
    inp_others_0: USER_INPUT
    inp_others_1: USER_INPUT
    
    # outputs
    mul: USER_OUTPUT
    mul_1: USER_OUTPUT
    mul_2: USER_OUTPUT
    
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}

AdditionalInputs#

In the case where you don't know how dynamic your inputs are, but you have an ample set of testing or profiling data that provides a fair sense of the representative inputs for your model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You specify all the possible inputs used to trace the program, and AdditionalInputs infers which inputs are dynamic based on which input shapes change across them.

Example:

import dataclasses
import torch
import torch.utils._pytree as pytree

@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor

pytree.register_dataclass(D)

class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t

input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)

print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, d_b, d_i: "Sym(s37)", d_f, d_t: "f32[s99]"):
             # File: /tmp/ipykernel_647/829931439.py:16 in forward, code: return d.i + d.f + d.t
            sym_float: "Sym(ToFloat(s37))" = torch.sym_float(d_i);  d_i = None
            add: "Sym(ToFloat(s37) + 3.0)" = sym_float + 3.0;  sym_float = None
            add_1: "f32[s99]" = torch.ops.aten.add.Tensor(d_t, add);  d_t = add = None
            return (add_1,)
            
Graph signature: 
    # inputs
    d_b: USER_INPUT
    d_i: USER_INPUT
    d_f: USER_INPUT
    d_t: USER_INPUT
    
    # outputs
    add_1: USER_OUTPUT
    
Range constraints: {s37: VR[0, int_oo], s99: VR[2, int_oo]}

Serialization#

To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zip file with a specific structure; the details of the structure are defined in the PT2 Archive Spec.

An example:

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
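
The deserialized program behaves like the original one, e.g. (continuing the snippet above):

# Run the loaded program on a fresh input
print(saved_exported_program.module()(torch.randn(5)))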

Export IR, Decompositions#

The graph produced by torch.export contains only ATen operators, which are the basic units of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, both functional and non-functional. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator is functional by checking `op._schema.is_mutable`.
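
For instance, a quick sketch of that check on a functional and a non-functional (in-place) overload of the same op:

import torch

print(torch.ops.aten.add.Tensor._schema.is_mutable)   # False -> functional
print(torch.ops.aten.add_.Tensor._schema.is_mutable)  # True  -> mutates its input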

This generic IR can be used to train in eager PyTorch Autograd.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = add_ = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        return (batch_norm,)
        

However, if you want to use the IR for inference, or to decrease the number of operators being used, you can lower the graph through the `ExportedProgram.run_decompositions()` API. This method decomposes the ATen operators into the ones specified in the decomposition table, and functionalizes the graph.

By specifying an empty set, we only perform functionalization without any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

We can see that the previously in-place operator `torch.ops.aten.add_.Tensor` has now been replaced with `torch.ops.aten.add.Tensor`, a functional operator.

We can also further lower this exported program to an operator set which only contains the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

We now see that `torch.ops.aten.conv2d.default` has been decomposed into `torch.ops.aten.convolution.default`. This is because `convolution` is a more "core" operator, as operations like `conv1d` and `conv2d` can be implemented with the same op.

We can also specify our own decomposition behavior:

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2);  convolution = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

Notice that instead of `torch.ops.aten.conv2d.default` being decomposed into `torch.ops.aten.convolution.default` alone, it is now decomposed into `torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, which conforms to our custom decomposition rule.

Limitations of torch.export#

As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs, as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation will cause a "graph break" and the unsupported operation will be run with default Python evaluation. In contrast, torch.export requires users to provide additional information or rewrite parts of their code to make it traceable.

Draft-export is a great resource for listing out the graph breaks that will be encountered when tracing a program, along with additional debug information to solve those errors.

ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.

TorchDynamo Unsupported#

When using torch.export with `strict=True`, it uses TorchDynamo to evaluate the program at the Python bytecode level and trace it into a graph. Compared to previous tracing frameworks, far fewer rewrites are needed to make a program traceable, but some Python features will still be unsupported. One way to work around dealing with graph breaks is to use non-strict export by changing the flag to `strict=False`.
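
A minimal sketch of flipping that flag (the module here is a trivial stand-in; non-strict export traces with the Python interpreter instead of TorchDynamo's bytecode analysis):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()

# strict=False skips TorchDynamo bytecode analysis entirely
ep = export(M(), (torch.randn(3),), strict=False)
print(ep)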

Data/Shape-Dependent Control Flow#

Data-dependent control flow (such as `if x.shape[0] > 2`) can also hit graph breaks when shapes are not specialized, as a tracing compiler cannot possibly deal with it without generating code for a combinatorially exploding number of paths. In such cases, users need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else-like control flow (more coming soon!).
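
For example, a minimal sketch (toy module) of rewriting a shape-dependent branch with torch.cond so that both paths are captured in the graph:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        # Instead of `if x.shape[0] > 2: ...`, use the explicit control flow operator
        return torch.cond(x.shape[0] > 2, true_fn, false_fn, (x,))

batch = Dim("batch")
ep = export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})
print(ep)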

You can also refer to this tutorial for more ways to address data-dependent errors.

Missing Fake/Meta Kernels for Operators#

When tracing, a FakeTensor kernel (also known as a meta kernel) is required for all operators. It is used to reason about the input/output shapes of the operator.
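
For instance, a hedged sketch of giving a custom operator a fake kernel via torch.library (the operator name "mylib::foo" is hypothetical):

import torch

# A hypothetical custom op; the real computation runs on actual tensors.
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

# The fake (meta) kernel only reasons about shapes/dtypes, which is exactly
# what export needs while tracing.
@foo.register_fake
def _(x):
    return torch.empty_like(x)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.foo(x)

ep = torch.export.export(M(), (torch.randn(3),))
print(ep)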

Please see this tutorial for more details.

In the unfortunate case where your model uses an ATen operator that does not yet have a FakeTensor kernel implementation, please file an issue.