评价此页

torch.export#

创建于: 2025年6月12日 | 最后更新于: 2025年6月12日

警告

此功能为原型,正在积极开发中,未来将会有重大更改。

概述#

torch.export.export() 接收一个 torch.nn.Module 并生成一个跟踪图,该图仅表示函数中的 Tensor 计算,以 AOT(Ahead-of-Time,提前)方式进行,随后可以以不同的输出执行或序列化。

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='y'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='add'),
                    target=None
                )
            ]
        )
    Range constraints: {}

torch.export 生成一个干净的中间表示(IR),具有以下不变量。有关 IR 的更多规范,请参阅此处

  • 健全性:保证对原始程序进行健全表示,并保持与原始程序相同的调用约定。

  • 规范化:图中没有 Python 语义。原始程序中的子模块被内联以形成一个完全扁平化的计算图。

  • 图属性:该图是纯函数式的,这意味着它不包含具有副作用的操作,例如变异或别名。它不修改任何中间值、参数或缓冲区。

  • 元数据:该图包含在跟踪过程中捕获的元数据,例如来自用户代码的堆栈跟踪。

在底层,torch.export 利用以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的帧评估 API 安全地跟踪 PyTorch 图。这极大地改进了图捕获体验,大大减少了完全跟踪 PyTorch 代码所需的重写次数。

  • AOT Autograd 提供了一个功能化的 PyTorch 图,并确保该图被分解/降低到 ATen 操作符集。

  • Torch FX (torch.fx) 是图的底层表示,允许灵活的基于 Python 的转换。

现有框架#

torch.compile() 也使用与 torch.export 相同的 PT2 堆栈,但略有不同:

  • JIT 与 AOTtorch.compile() 是一个 JIT 编译器,不打算用于在部署之外生成编译后的工件。

  • 部分图捕获与完整图捕获:当 torch.compile() 遇到模型中无法跟踪的部分时,它将“图中断”并回退到在 Eager Python 运行时中运行程序。相比之下,torch.export 旨在获取 PyTorch 模型的完整图表示,因此当遇到无法跟踪的内容时,它将报错。由于 torch.export 生成的完整图与任何 Python 功能或运行时分离,因此该图可以保存、加载并在不同的环境和语言中运行。

  • 可用性权衡:由于 torch.compile() 能够在遇到无法跟踪的内容时回退到 Python 运行时,因此它更加灵活。torch.export 则会要求用户提供更多信息或重写代码以使其可跟踪。

torch.fx.symbolic_trace() 相比,torch.export 使用 TorchDynamo 进行跟踪,TorchDynamo 在 Python 字节码级别运行,使其能够跟踪不受 Python 运算符重载支持限制的任意 Python 构造。此外,torch.export 能够细粒度地跟踪 Tensor 元数据,因此对 Tensor 形状等条件的判断不会导致跟踪失败。通常,torch.export 有望在更多用户程序上运行,并生成更底层的图(在 torch.ops.aten 运算符级别)。请注意,用户仍然可以在 torch.export 之前使用 torch.fx.symbolic_trace() 作为预处理步骤。

torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,只有直线控制流(显式控制流运算符除外)。

torch.jit.trace() 相比,torch.export 是健全的:它能够跟踪对大小执行整数计算的代码,并记录所有必要的附加条件,以表明特定跟踪对于其他输入是有效的。

导出 PyTorch 模型#

示例#

主要入口点是 torch.export.export(),它接受一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到 torch.export.ExportedProgram 中。例如:

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
            # code: a = self.conv(x)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])

            # code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)

            # code: return self.maxpool(self.relu(a))
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
            return (max_pool2d,)

Graph signature:
    ExportGraphSignature(
        input_specs=[
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_weight'),
                target='conv.weight',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_bias'),
                target='conv.bias',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='x'),
                target=None,
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='constant'),
                target=None,
                persistent=None
            )
        ],
        output_specs=[
            OutputSpec(
                kind=<OutputKind.USER_OUTPUT: 1>,
                arg=TensorArgument(name='max_pool2d'),
                target=None
            )
        ]
    )
Range constraints: {}

检查 ExportedProgram,我们可以注意到以下几点:

  • torch.fx.Graph 包含原始程序的计算图,以及原始代码的记录,便于调试。

  • 该图仅包含 此处 找到的 torch.ops.aten 运算符和自定义运算符,并且是完全功能的,不包含任何就地运算符,例如 torch.add_

  • 参数(卷积的权重和偏差)被提升为图的输入,导致图中没有 get_attr 节点,而这些节点在 torch.fx.symbolic_trace() 的结果中是存在的。

  • torch.export.ExportGraphSignature 建模输入和输出签名,并指定哪些输入是参数。

  • 图中每个节点产生的张量的最终形状和数据类型已注明。例如,convolution 节点将产生一个数据类型为 torch.float32,形状为 (1, 16, 256, 256) 的张量。

非严格导出#

在 PyTorch 2.3 中,我们引入了一种新的跟踪模式,称为 **非严格模式**。它仍在加强中,因此如果您遇到任何问题,请在 Github 上提交并带有“oncall: export”标签。

在 *非严格模式* 下,我们使用 Python 解释器跟踪程序。您的代码将完全按照其在急切模式下的方式执行;唯一的区别是所有 Tensor 对象都将被 ProxyTensors 替换,后者将记录它们的所有操作到图中。

在 *严格* 模式下(目前是默认模式),我们首先使用 TorchDynamo(一个字节码分析引擎)跟踪程序。TorchDynamo 实际上并不执行您的 Python 代码。相反,它对其进行符号分析并根据结果构建一个图。此分析允许 torch.export 提供更强的安全性保证,但并非所有 Python 代码都受支持。

您可能希望使用非严格模式的一个示例是,如果您遇到无法轻易解决的 TorchDynamo 不支持的功能,并且您知道该 Python 代码并非计算所必需的。例如:

import contextlib
import torch

class ContextManager():
    def __init__(self):
        self.count = 0
    def __enter__(self):
        self.count += 1
    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

在此示例中,第一次使用非严格模式(通过 strict=False 标志)的调用成功跟踪,而第二次使用严格模式(默认)的调用导致失败,其中 TorchDynamo 无法支持上下文管理器。一个选项是重写代码(请参阅 torch.export 的限制),但考虑到上下文管理器不影响模型中的张量计算,我们可以选择非严格模式的结果。

用于训练和推理的导出#

在 PyTorch 2.5 中,我们引入了一个新的 API,名为 export_for_training()。它仍在加强中,因此如果您遇到任何问题,请在 Github 上提交并带有“oncall: export”标签。

在此 API 中,我们生成包含所有 ATen 运算符(包括功能性和非功能性)的最通用 IR,可用于在急切的 PyTorch Autograd 中进行训练。此 API 适用于急切训练用例,例如 PT2 量化,并且很快将成为 torch.export.export 的默认 IR。要进一步了解此更改背后的动机,请参阅 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206

当此 API 与 run_decompositions() 结合使用时,您应该能够获得具有任何所需分解行为的推理 IR。

为了展示一些例子

class ConvBatchnorm(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)

ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
            batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
            return (batch_norm,)

从上面的输出中,您可以看到 export_for_training() 产生的 ExportedProgram 与 export() 几乎相同,只是图中的运算符不同。您可以看到我们以最通用的形式捕获了 batch_norm。此操作是非功能性的,在运行推理时将被转换为不同的操作。

您还可以通过 run_decompositions() 和任意自定义从该 IR 转换为推理 IR。

# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
            return (getitem_3, getitem_4, add, getitem)

在这里,您可以看到我们保留了 IR 中的 conv2d 操作,同时分解了其余部分。现在,该 IR 是一个功能性 IR,包含核心 Aten 运算符,除了 conv2d

您可以通过直接注册您选择的分解行为来执行更多自定义。

您可以通过直接注册自定义分解行为来执行更多自定义

# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
            mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
            return (getitem_3, getitem_4, add, getitem)

表达动态性#

默认情况下,torch.export 将假定所有输入形状都是 **静态的**,并专门化导出的程序以适应这些维度来跟踪程序。但是,某些维度(例如批处理维度)可以是动态的,并且在每次运行之间都会有所不同。必须使用 torch.export.Dim() API 创建这些维度,并通过 dynamic_shapes 参数将其传递给 torch.export.export() 来指定这些维度。一个例子:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):

         # code: out1 = self.branch1(x1)
        linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
        relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)

         # code: out2 = self.branch2(x2)
        linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
        relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)

         # code: return (out1 + self.buffer, out2)
        add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
        return (add, relu_1)

Range constraints: {s0: VR[0, int_oo]}

其他需要注意的事项:

  • 通过 torch.export.Dim() API 和 dynamic_shapes 参数,我们指定了每个输入的第一维度是动态的。查看输入 x1x2,它们的符号形状是 (s0, 64) 和 (s0, 128),而不是我们作为示例输入传入的形状为 (32, 64) 和 (32, 128) 的张量。s0 是一个符号,表示此维度可以是值的范围。

  • exported_program.range_constraints 描述了图中出现的每个符号的范围。在这种情况下,我们看到 s0 的范围是 [0, int_oo]。由于此处难以解释的技术原因,它们被假定不为 0 或 1。这不是一个错误,并且不一定意味着导出的程序对于维度 0 或 1 不起作用。有关此主题的深入讨论,请参阅 0/1 特殊化问题

我们还可以指定输入形状之间更具表达力的关系,例如一对形状可能相差一、一个形状可能是另一个形状的两倍,或者一个形状是偶数。一个例子:

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
        # code: return x + y[1:]
        slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
        add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
        return (add,)

Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}

一些注意事项:

  • 通过为第一个输入指定 {0: dimx},我们看到第一个输入的最终形状现在是动态的,为 [s0]。现在通过为第二个输入指定 {0: dimy},我们看到第二个输入的最终形状也是动态的。但是,因为我们表达了 dimy = dimx + 1,所以 y 的形状没有包含新符号,而是与 x 中使用的相同符号 s0 表示。我们可以看到 dimy = dimx + 1 的关系通过 s0 + 1 显示。

  • 查看范围约束,我们看到 s0 的范围是 [3, 6],这是最初指定的,我们可以看到 s0 + 1 的已解析范围是 [4, 7]。

序列化#

为了保存 ExportedProgram,用户可以使用 torch.export.save()torch.export.load() API。一个惯例是使用 .pt2 文件扩展名保存 ExportedProgram

一个例子

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

专门化#

理解 torch.export 行为的关键概念是 *静态* 值和 *动态* 值之间的区别。

一个 *动态* 值是指在每次运行中可以改变的值。它们的行为类似于 Python 函数的普通参数——您可以为参数传递不同的值,并期望您的函数执行正确的操作。张量 *数据* 被视为动态的。

一个 *静态* 值是在导出时固定的值,在导出的程序执行之间不能更改。当在跟踪期间遇到该值时,导出器会将其视为常量并将其硬编码到图中。

当执行操作(例如 x + y)并且所有输入都是静态的时,操作的输出将直接硬编码到图中,并且操作不会显示(即它将被常量折叠)。

当一个值被硬编码到图中时,我们称该图已经 *专门化* 到该值。

以下值是静态的:

输入张量形状#

默认情况下,torch.export 将跟踪程序,根据输入张量的形状进行专门化,除非通过 torch.exportdynamic_shapes 参数将维度指定为动态。这意味着如果存在依赖于形状的控制流,torch.export 将专门化给定示例输入所采用的分支。例如:

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[10, 2]"):
        # code: return x + 1
        add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
        return (add,)

条件 (x.shape[0] > 5) 不会出现在 ExportedProgram 中,因为示例输入的静态形状为 (10, 2)。由于 torch.export 专门化输入静态形状,因此 else 分支 (x - 1) 将永远不会被执行。为了保留基于跟踪图中张量形状的动态分支行为,需要使用 torch.export.Dim() 来指定输入张量 (x.shape[0]) 的维度为动态,并且需要 重写 源代码。

请注意,作为模块状态一部分的张量(例如参数和缓冲区)始终具有静态形状。

Python 原生类型#

torch.export 也专门处理 Python 原生类型,例如 intfloatboolstr。但是,它们确实有动态变体,例如 SymIntSymFloatSymBool

例如

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[2, 2]", const, times):
            # code: x = x + const
            add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
            add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
            return (add_2,)

由于整数是专门化的,所有 torch.ops.aten.add.Tensor 操作都使用硬编码常量 1 而不是 const 进行计算。如果用户在运行时传递给 const 的值(例如 2)与导出时使用的值(1)不同,这将导致错误。此外,for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用“内联”到图中,并且输入 times 从未使用过。

Python 容器#

Python 容器(ListDictNamedTuple 等)被认为是具有静态结构的。

torch.export 的限制#

图中断#

由于 torch.export 是从 PyTorch 程序捕获计算图的一次性过程,因此它最终可能会遇到程序中无法跟踪的部分,因为它几乎不可能支持跟踪所有 PyTorch 和 Python 功能。对于 torch.compile,不支持的操作将导致“图中断”,并且不支持的操作将使用默认的 Python 评估运行。相比之下,torch.export 将要求用户提供额外信息或重写部分代码以使其可跟踪。由于跟踪基于 TorchDynamo,它在 Python 字节码级别进行评估,因此与以前的跟踪框架相比,所需的重写将显著减少。

遇到图中断时,ExportDB 是一个很好的资源,可用于了解支持和不支持的程序类型,以及重写程序以使其可跟踪的方法。

解决图中断的一种方法是使用 非严格导出

数据/形状相关的控制流#

当形状未专门化时,数据相关的控制流 (if x.shape[0] > 2) 也可能导致图中断,因为跟踪编译器无法在不生成大量路径的代码的情况下处理这种情况。在这种情况下,用户需要使用特殊的控制流运算符重写其代码。目前,我们支持 torch.cond 来表达 if-else 类似的控制流(更多即将推出!)。

操作符缺少 Fake/Meta/Abstract 内核#

在跟踪时,所有操作符都需要 FakeTensor 内核(也称为元内核、抽象实现)。这用于推断此操作符的输入/输出形状。

请参阅 torch.library.register_fake() 了解更多详细信息。

不幸的是,如果您的模型使用了尚未实现 FakeTensor 内核的 ATen 操作符,请提交一个问题。

API 参考#

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=())[source]#

export() 接受任何 nn.Module 和示例输入,并生成一个跟踪图,该图仅以预先 (AOT) 的方式表示函数的 Tensor 计算,随后可以使用不同的输入执行或序列化。跟踪图 (1) 生成功能 ATen 运算符集(以及任何用户指定的自定义运算符)中的标准化运算符,(2) 消除了所有 Python 控制流和数据结构(某些例外情况),以及 (3) 记录了证明这种规范化和控制流消除对于未来输入是健全的所需的一组形状约束。

健全性保证

在跟踪期间,export() 会注意用户程序和底层 PyTorch 运算符内核所做的形状相关假设。输出 ExportedProgram 仅在这些假设成立时才被视为有效。

跟踪会对输入张量的形状(而非值)做出假设。这些假设必须在图捕获时进行验证,以便 export() 成功。具体来说:

  • 对输入张量静态形状的假设将自动验证,无需额外工作。

  • 输入张量动态形状的假设需要显式指定,方法是使用 Dim() API 构造动态维度,并通过 dynamic_shapes 参数将它们与示例输入关联起来。

如果任何假设无法验证,将引发致命错误。发生这种情况时,错误消息将包含验证假设所需的规范建议修复。例如,export() 可能会建议对动态维度 dim0_x 的定义进行以下修复,假设它出现在与输入 x 关联的形状中,该输入先前定义为 Dim("dim0_x")

dim = Dim("dim0_x", max=5)

这个例子意味着生成的代码要求输入 x 的维度 0 小于或等于 5 才能有效。您可以检查对动态维度定义的建议修复,然后将其逐字复制到您的代码中,而无需更改 export() 调用的 dynamic_shapes 参数。

参数
  • mod (Module) – 我们将跟踪此模块的 forward 方法。

  • args (tuple[Any, ...]) – 示例位置输入。

  • kwargs (Optional[dict[str, Any]]) – 可选示例关键字输入。

  • dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any], list[Any]]]) –

    一个可选参数,其类型应为:1) 从 f 的参数名到其动态形状规范的字典,2) 一个元组,按原始顺序指定每个输入的动态形状规范。如果要指定关键字参数的动态性,则需要按照原始函数签名中定义的顺序传递它们。

    张量参数的动态形状可以指定为 (1) 从动态维度索引到 Dim() 类型的字典,其中不需要在此字典中包含静态维度索引,但如果包含,则应将其映射到 None;或 (2) Dim() 类型或 None 的元组/列表,其中 Dim() 类型对应于动态维度,静态维度用 None 表示。字典或张量元组/列表的参数通过使用映射或包含规范的序列递归指定。

  • strict (bool) – 禁用时(默认),导出函数将通过 Python 运行时跟踪程序,这本身不会验证图中烘焙的一些隐式假设。它仍将验证大多数关键假设,例如形状安全性。启用时(通过设置 strict=True),导出函数将通过 TorchDynamo 跟踪程序,这将确保生成的图的健全性。TorchDynamo 对 Python 功能的覆盖范围有限,因此您可能会遇到更多错误。请注意,切换此参数不会影响生成的 IR 规范不同,无论此处传递什么值,模型都将以相同的方式序列化。

  • preserve_module_call_signature (tuple[str, ...]) – 一个子模块路径列表,其原始调用约定将作为元数据保留。在调用 torch.export.unflatten 时,将使用此元数据来保留模块的原始调用约定。

返回

一个包含跟踪可调用对象的 ExportedProgram

返回类型

ExportedProgram

可接受的输入/输出类型

可接受的输入(用于 argskwargs)和输出类型包括:

  • 原始类型,即 torch.Tensorintfloatboolstr

  • 数据类,但它们必须首先通过调用 register_dataclass() 进行注册。

  • 包含上述所有类型的 dictlisttuplenamedtupleOrderedDict 组成的(嵌套)数据结构。

torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

警告

在积极开发中,保存的文件可能无法在 PyTorch 的新版本中使用。

ExportedProgram 保存到文件对象。然后可以使用 Python API torch.export.load 加载。

参数
  • ep (ExportedProgram) – 要保存的导出程序。

  • f (str | os.PathLike[str] | IO[bytes]) – 实现 write 和 flush) 或包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为 f 的一部分存储。

  • opset_version (Optional[Dict[str, int]]) – 运算符集名称到此运算符集版本的映射

  • pickle_protocol (int) – 可以指定以覆盖默认协议

示例

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, "exported_program.pt2")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#

警告

在积极开发中,保存的文件可能无法在 PyTorch 的新版本中使用。

加载先前用 torch.export.save 保存的 ExportedProgram

参数
  • f (str | os.PathLike[str] | IO[bytes]) – 一个文件类对象(必须实现 write 和 flush)或一个包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 此映射中给出的额外文件名将被加载,其内容将存储在提供的映射中。

  • expected_opset_version (Optional[Dict[str, int]]) – 运算符集名称到预期运算符集版本的映射

返回

一个 ExportedProgram 对象

返回类型

ExportedProgram

示例

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load("exported_program.pt2")

# Load ExportedProgram from io.BytesIO object
with open("exported_program.pt2", "rb") as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {"foo.txt": ""}  # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
print(ep(torch.randn(5)))
torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False)[source]#

torch.export.export 的一个版本,旨在一致地生成 ExportedProgram,即使存在潜在的健全性问题,并生成列出发现问题的报告。

返回类型

ExportedProgram

torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#

将数据类注册为 torch.export.export() 的有效输入/输出类型。

参数
  • cls (type[Any]) – 要注册的数据类类型

  • serialized_type_name (Optional[str]) – 数据类的序列化名称。这是

  • this (如果您想序列化包含的 pytree TreeSpec,则需要) –

  • dataclass.

示例

import torch
from dataclasses import dataclass


@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int


@dataclass
class OutputDataClass:
    res: torch.Tensor


torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)


class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)


ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)
class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#

Dim 类允许用户在其导出的程序中指定动态性。通过用 Dim 标记维度,编译器将维度与包含动态范围的符号整数相关联。

该 API 可以通过两种方式使用:Dim 提示(即自动动态形状:Dim.AUTODim.DYNAMICDim.STATIC),或命名 Dim(即 Dim(“name”, min=1, max=2))。

Dim 提示提供了最低的导出障碍,用户只需指定维度是动态的、静态的,还是留给编译器决定 (Dim.AUTO)。导出过程将自动推断 min/max 范围和维度之间的剩余约束。

示例

class Foo(nn.Module):
    def forward(self, x, y):
        assert x.shape[0] == 4
        assert y.shape[0] >= 16
        return x @ y


x = torch.randn(4, 8)
y = torch.randn(8, 16)
dynamic_shapes = {
    "x": {0: Dim.AUTO, 1: Dim.AUTO},
    "y": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

在这里,如果我们用 Dim.DYNAMIC 替换所有 Dim.AUTO 的使用,导出将引发异常,因为 x.shape[0] 被模型限制为静态。

维度之间更复杂的关系也可能由编译器作为运行时断言节点生成代码,例如 (x.shape[0] + y.shape[1]) % 4 == 0,如果运行时输入不满足此类约束,则会引发异常。

您还可以为 Dim 提示指定最小-最大边界,例如 Dim.AUTO(min=16, max=32)Dim.DYNAMIC(max=64),编译器将在这些范围内推断其余约束。如果有效范围完全超出用户指定的范围,将引发异常。

命名 Dim 提供了更严格的动态性指定方式,如果编译器推断的约束与用户规范不匹配,则会引发异常。例如,导出前面的模型时,用户需要以下 dynamic_shapes 参数

s0 = Dim("s0")
s1 = Dim("s1", min=16)
dynamic_shapes = {
    "x": {0: 4, 1: s0},
    "y": {0: s0, 1: s1},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

命名维度还允许指定维度之间的关系,最高可达单变量线性关系。例如,以下表示一个维度是另一个维度的倍数加上 4

s0 = Dim("s0")
s1 = 3 * s0 + 4
class torch.export.dynamic_shapes.ShapesCollection[source]#

dynamic_shapes 的构建器。用于将动态形状规范分配给输入中出现的张量。

这在 args() 是嵌套输入结构时特别有用,此时索引输入张量比在 dynamic_shapes() 规范中复制 args() 的结构更容易。

示例

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

为了指定整数的动态性,我们需要首先使用 _IntWrapper 包装整数,以便我们为每个整数提供“唯一标识符”。

示例

args = {"x": tensor_x, "others": [int_x, int_y]}
# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)
dynamic_shapes(m, args, kwargs=None)[source]#

根据 args()kwargs() 生成 dynamic_shapes() pytree 结构。

class torch.export.dynamic_shapes.AdditionalInputs[source]#

根据附加输入推断 dynamic_shapes。

这对于部署工程师特别有用,因为他们一方面可能拥有充足的测试或分析数据,可以提供对模型具有代表性输入的公平理解,但另一方面,可能对模型了解不足,无法猜测哪些输入形状应该是动态的。

与原始形状不同的输入形状被视为动态的;相反,与原始形状相同的输入形状被视为静态的。此外,我们验证附加输入对于导出的程序是有效的。这保证了用它们而不是原始输入进行跟踪会生成相同的图。

示例

args0, kwargs0 = ...  # example inputs for export

# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
dynamic_shapes.add(args1, kwargs1)
...
dynamic_shapes.add(argsN, kwargsN)

torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
add(args, kwargs=None)[source]#

附加输入 args()kwargs()

dynamic_shapes(m, args, kwargs=None)[source]#

通过合并原始输入 args()kwargs() 以及每个附加输入 args 和 kwargs 的形状,推断出 dynamic_shapes() pytree 结构。

verify(ep)[source]#

验证导出的程序对每个附加输入都有效。

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#

当使用 dynamic_shapes() 导出时,如果规范与从模型跟踪推断出的约束不匹配,则导出可能会因 ConstraintViolation 错误而失败。错误消息可能会提供建议的修复——可以对 dynamic_shapes() 进行更改以成功导出。

ConstraintViolation 错误消息示例

Suggested fixes:

    dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
    dim = 4  # this specializes to a constant
    dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

这是一个辅助函数,它接受 ConstraintViolation 错误消息和原始 dynamic_shapes() 规范,并返回一个新的 dynamic_shapes() 规范,其中包含建议的修复。

使用示例

try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)
返回类型

Union[dict[str, Any], tuple[Any], list[Any]]

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#

来自 export() 的程序包。它包含一个表示张量计算的 torch.fx.Graph、一个包含所有提升参数和缓冲区张量值的 state_dict 以及各种元数据。

您可以像原始通过 export() 跟踪的可调用对象一样调用 ExportedProgram,使用相同的调用约定。

要在图上执行转换,请使用 .module 属性访问 torch.fx.GraphModule。然后,您可以使用 FX 转换 重写图。之后,您可以简单地再次使用 export() 来构造正确的 ExportedProgram。

graph#
graph_signature#
state_dict#
constants#
range_constraints#
module_call_graph#
example_inputs#
module()[source]#

返回一个包含所有参数/缓冲区的自包含 GraphModule。

返回类型

模块

run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#

对导出的程序运行一组分解并返回一个新的导出程序。默认情况下,我们将运行核心 ATen 分解以获取 核心 ATen 运算符集 中的运算符。

目前,我们不分解联合图。

参数

decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 一个可选参数,指定 Aten 运算符的分解行为 (1) 如果为 None,我们分解为核心 Aten 分解 (2) 如果为空,我们不分解任何运算符

返回类型

ExportedProgram

一些例子

如果你不想分解任何东西

ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={})

如果您想获取核心 aten 运算符集,但排除某些运算符,可以这样做:

ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)
class torch.export.ExportGraphSignature(input_specs, output_specs)[source]#

ExportGraphSignature 建模了 Export Graph 的输入/输出签名,Export Graph 是一个具有更强不变性保证的 fx.Graph。

导出图是功能性的,并且不会通过 getattr 节点访问图中“状态”如参数或缓冲区。相反,export() 保证参数、缓冲区和常量张量作为输入从图中提升出来。类似地,对缓冲区的任何修改也不包含在图中,相反,修改后的缓冲区更新值被建模为导出图的附加输出。

所有输入和输出的顺序是:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块:

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer("my_buffer1", torch.tensor(3.0))
        self.register_buffer("my_buffer2", torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (
            x1 + self.my_parameter
        ) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output


mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

生成的图是非功能性的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)

非功能性图的最终 ExportGraphSignature 将是:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_1: USER_OUTPUT

要获得功能性图,可以使用 run_decompositions()

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
ep = ep.run_decompositions()

生成的图是功能性的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_2, add_1)

功能图的最终 ExportGraphSignature 将是:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_2: BUFFER_MUTATION target='my_buffer2'
add_1: USER_OUTPUT
class torch.export.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[source]#
class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source]#
class torch.export.decomp_utils.CustomDecompTable[source]#

这是一个专门用于在导出中处理 `decomp_table` 的自定义字典。我们需要它的原因是,在新版本中,你只能通过从 `decomp_table` 中 **删除** 操作来保留它。这对于自定义操作来说是有问题的,因为我们不知道自定义操作何时会被实际加载到调度器中。因此,我们需要记录自定义操作,直到我们真正需要将其具体化(即运行分解过程时)。

我们保持的不变量是:
  1. 所有 Aten 分解都在初始化时加载。

  2. 当用户从表中读取时,我们会具体化所有操作,以增加调度器选择自定义操作的可能性。

  3. 如果是写入操作,我们不一定具体化。

  4. 我们在导出期间,在调用 `run_decompositions()` 之前,最后一次加载。

copy()[source]#
返回类型

CustomDecompTable

items()[source]#
keys()[source]#
materialize()[source]#
返回类型

dict[torch._ops.OperatorBase, Callable]

pop(*args)[source]#
update(other_dict)[source]#
torch.export.exported_program.default_decompositions()[source]#

这是默认的分解表,其中包含所有 ATEN 运算符到核心 Aten 操作集的分解。请将此 API 与 run_decompositions() 一起使用。

返回类型

CustomDecompTable

class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]#

ExportGraphSignature 模拟导出图的输入/输出签名,该图是一个具有更强不变量保证的 fx.Graph。

导出图是功能性的,不通过 getattr 节点访问图内的“状态”(如参数或缓冲区)。相反,export() 保证参数、缓冲区和常量张量作为输入被提取到图之外。类似地,对缓冲区的任何修改也不包含在图中,相反,修改后的缓冲区的新值被建模为导出图的额外输出。

所有输入和输出的顺序是:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块:

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer("my_buffer1", torch.tensor(3.0))
        self.register_buffer("my_buffer2", torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (
            x1 + self.my_parameter
        ) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output


mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

生成的图是非功能性的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)

非功能性图的最终 ExportGraphSignature 将是:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_1: USER_OUTPUT

要获得功能性图,可以使用 run_decompositions()

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
ep = ep.run_decompositions()

生成的图是功能性的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_2, add_1)

功能图的最终 ExportGraphSignature 将是:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_2: BUFFER_MUTATION target='my_buffer2'
add_1: USER_OUTPUT
replace_all_uses(old, new)[source]#

替换签名中所有使用旧名称的地方,替换为新名称。

get_replace_hook(replace_inputs=False)[source]#
class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[source]#
class torch.export.graph_signature.InputKind(value)[source]#

枚举。

class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source]#
class torch.export.graph_signature.OutputKind(value)[source]#

枚举。

class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source]#
class torch.export.graph_signature.SymIntArgument(name: str)[source]#
class torch.export.graph_signature.SymBoolArgument(name: str)[source]#
class torch.export.graph_signature.SymFloatArgument(name: str)[source]#
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source]#
class torch.export.unflatten.FlatArgsAdapter[source]#

使用 input_spec 调整输入参数以与 target_spec 对齐。

abstract adapt(target_spec, input_spec, input_args, metadata=None, obj=None)[source]#

注意:此适配器可能会修改给定的 input_args_with_path

返回类型

list[Any]

class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#

一个使用 torch.fx.Interpreter 而非 GraphModule 通常使用的代码生成来执行的模块。这提供了更好的堆栈跟踪信息,并使调试执行更容易。

class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#

一个模块,它带有一系列 InterpreterModules,对应于该模块的一系列调用。每次调用该模块都会分派到下一个 InterpreterModule,并在最后一个之后循环返回。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#

解平化一个 ExportedProgram,生成一个与原始即时模块具有相同模块层次结构的模块。如果您尝试将 torch.export 与另一个需要模块层次结构而非 torch.export 通常生成的扁平图的系统一起使用,这会很有用。

注意

解平化模块的 args/kwargs 不一定会与即时模块匹配,因此模块交换(例如 self.submod = new_mod)不一定有效。如果您需要替换模块,则需要设置 torch.export.export()preserve_module_call_signature 参数。

参数
  • module (ExportedProgram) – 要解平化的 ExportedProgram。

  • flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出模块的不匹配,则调整扁平参数。

返回

一个 UnflattenedModule 实例,它与原始即时模块在导出前具有相同的模块层次结构。

返回类型

UnflattenedModule

torch.export.passes.move_to_device_pass(ep, location)[source]#

将导出的程序移动到给定设备。

参数
  • ep (ExportedProgram) – 要移动的导出程序。

  • location (Union[torch.device, str, Dict[str, str]]) – 将导出程序移动到的设备。如果为字符串,则解释为设备名称。如果为字典,则解释为从现有设备到目标设备的映射。

返回

已移动的导出程序。

返回类型

ExportedProgram