torch.export#
Created On: Jun 12, 2025 | Last Updated On: Dec 3, 2025
Overview#
torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.
import torch
from torch.export import export, ExportedProgram

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # File: /tmp/ipykernel_558/2550508656.py:6 in forward, code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x); x = None

            # File: /tmp/ipykernel_558/2550508656.py:7 in forward, code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y); y = None

            # File: /tmp/ipykernel_558/2550508656.py:8 in forward, code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}
torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.

Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.

Graph properties: By default, the graph may contain both functional and non-functional operators (including mutations). To get a purely functional graph, use run_decompositions(), which removes mutations and aliasing.

Metadata: The graph contains metadata captured during tracing, such as stack traces from the user's code.

Under the hood, torch.export leverages the following latest technologies:

TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph capturing experience, with many fewer rewrites needed in order to fully trace PyTorch code.

AOT Autograd ensures the graph is decomposed/lowered to the ATen operator set. When using run_decompositions(), it can also provide functionalization.

Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.
Existing frameworks#
torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

JIT vs. AOT: torch.compile() is a JIT compiler, whereas torch.export is aimed at producing compiled artifacts for deployment.

Partial vs. full graph capture: When torch.compile() runs into an untraceable part of a model, it will "graph break" and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when it reaches something untraceable. Since torch.export produces a graph disjoint from any Python features or runtime, the graph can then be saved, loaded, and run in different environments and languages.

Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is much more flexible, whereas torch.export requires users to provide more information or rewrite their code to make it traceable.

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures (unless using explicit control flow operators), but it supports more Python language features due to its comprehensive coverage of Python bytecode. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.

Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computations on sizes, and it records all of the side conditions necessary to ensure that a particular trace is valid for other inputs.
Exporting a PyTorch Model#
The main entrypoint is torch.export.export(), which takes a torch.nn.Module and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:
import torch
from torch.export import export, ExportedProgram

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:553 in forward, code: return self._conv_forward(input, self.weight, self.bias)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); x = p_conv_weight = p_conv_bias = None

            # File: /tmp/ipykernel_558/2848084713.py:16 in forward, code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant); conv2d = constant = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_); add_ = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/pooling.py:224 in forward, code: return F.max_pool2d(
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]); relu = None
            return (max_pool2d,)

Graph signature:
    # inputs
    p_conv_weight: PARAMETER target='conv.weight'
    p_conv_bias: PARAMETER target='conv.bias'
    x: USER_INPUT
    constant: USER_INPUT

    # outputs
    max_pool2d: USER_OUTPUT

Range constraints: {}
tensor([[[[1.2826, 1.5580, 1.2994, ..., 1.7939, 1.7223, 1.9093],
[1.9136, 1.8306, 1.9733, ..., 1.8196, 1.5307, 1.7811],
[1.3379, 2.3498, 1.6252, ..., 1.2179, 1.3837, 1.3712],
...,
[1.5983, 1.8066, 2.0805, ..., 1.4232, 1.9107, 2.3247],
[1.8269, 2.1419, 2.0269, ..., 1.5310, 1.6102, 1.7968],
[1.6521, 1.4793, 1.5691, ..., 1.3915, 1.5248, 1.4336]],
[[2.1476, 1.7757, 1.6806, ..., 1.6925, 2.7682, 1.8764],
[2.0719, 2.3811, 2.1348, ..., 1.5409, 1.9257, 2.2841],
[2.1350, 2.0278, 1.9268, ..., 1.9617, 2.2835, 2.2297],
...,
[1.6152, 1.3856, 1.8618, ..., 2.8097, 1.3546, 1.9110],
[1.9449, 2.0853, 2.1837, ..., 1.1103, 2.9785, 1.7965],
[2.2465, 2.4192, 1.7677, ..., 1.8937, 2.0854, 1.7335]],
[[1.6582, 1.6809, 1.5678, ..., 1.4650, 2.5201, 1.6977],
[2.5051, 2.1685, 2.6054, ..., 2.1139, 1.7916, 2.0193],
[2.2462, 1.6973, 1.9110, ..., 2.1940, 2.2802, 1.8960],
...,
[1.6740, 1.5380, 1.9010, ..., 2.4108, 1.6413, 1.6824],
[2.1880, 1.8369, 2.1543, ..., 1.8302, 1.9309, 2.5490],
[2.2255, 2.6299, 2.4285, ..., 1.8004, 1.9514, 1.6198]],
...,
[[1.6318, 1.4941, 1.6787, ..., 2.0484, 1.1358, 2.0613],
[2.0916, 1.9009, 2.3360, ..., 2.8306, 1.5998, 2.2715],
[1.9402, 1.8655, 2.0556, ..., 1.8865, 1.3384, 1.5239],
...,
[2.0466, 1.6921, 1.8156, ..., 1.5643, 2.1106, 1.6655],
[1.3279, 1.8963, 1.8651, ..., 1.7580, 1.5840, 1.7254],
[1.2859, 1.6108, 2.8042, ..., 1.7552, 2.3246, 2.1089]],
[[1.7599, 1.4617, 1.3991, ..., 1.5102, 1.8866, 1.9577],
[1.5342, 1.5302, 1.5428, ..., 1.9546, 1.7829, 1.7457],
[1.5917, 2.3962, 1.8921, ..., 1.5050, 1.5800, 1.8842],
...,
[1.6564, 1.0485, 2.4266, ..., 1.0107, 1.2459, 1.6084],
[1.7559, 1.4171, 1.5450, ..., 1.8679, 1.6174, 1.0972],
[1.4823, 2.0421, 1.7609, ..., 1.3987, 1.8046, 1.4119]],
[[1.5778, 1.9460, 1.7058, ..., 1.8046, 1.7689, 1.5508],
[1.4104, 2.0936, 1.6594, ..., 2.2562, 1.7512, 2.0105],
[2.0650, 2.5356, 2.0778, ..., 1.6606, 1.7495, 1.8828],
...,
[1.8065, 1.9066, 1.3303, ..., 1.2468, 1.7625, 1.7471],
[1.6914, 1.3748, 2.0785, ..., 1.6700, 2.2404, 1.6258],
[2.3748, 1.9729, 1.7573, ..., 1.5317, 2.2878, 1.7637]]]],
grad_fn=<MaxPool2DWithIndicesBackward0>)
Inspecting the ExportedProgram, we can note the following:

The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.

The graph contains only torch.ops.aten operators (found here) and custom operators.

The parameters (weight and bias to conv) are lifted as inputs to the graph, so there are no get_attr nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().

The torch.export.ExportGraphSignature models the input and output signature, and specifies which inputs are parameters.

The shape and dtype of the tensor produced by each node in the graph is recorded. For example, the conv2d node will result in a tensor of dtype torch.float32 and shape (1, 16, 256, 256).
Expressing Dynamism#
By default, torch.export will trace the program assuming all input shapes are static, and specialize the exported program to those dimensions. One consequence of this is that at runtime, the program will not work on inputs with different shapes, even if they are valid in eager mode.

An example:
import torch
import traceback as tb

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

ep = torch.export.export(M(), example_args)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear); linear = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None

            # File: /tmp/ipykernel_558/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
            return (add, relu_1)

Graph signature:
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT

    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT

Range constraints: {}
Traceback (most recent call last):
  File "/tmp/ipykernel_558/1522925308.py", line 28, in <module>
    ep.module()(*example_args2)  # fails
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 936, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 455, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 442, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
    return inner()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1830, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.25", line 11, in forward
    _guards_fn = self._guards_fn(x1, x2); _guards_fn = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 216, in inner
    return func(*args, **kwargs)
  File "<string>", line 3, in _
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/__init__.py", line 2228, in _assert
    assert condition, message
AssertionError: Guard failed: x1.size()[0] == 32
However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the torch.export.Dim() API to create them, and by passing them to torch.export.export() through the dynamic_shapes argument.
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2)  # success
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear); linear = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None

            # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None

            # File: /tmp/ipykernel_558/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
            return (add, relu_1)

Graph signature:
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT

    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT

Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.4829, 1.0000, ..., 1.3803, 1.0924, 1.0000],
[1.8506, 1.7225, 1.0000, ..., 1.0611, 1.9632, 1.2653],
[1.0000, 1.5856, 1.4536, ..., 1.0413, 1.1314, 1.2240],
...,
[1.6243, 1.0000, 1.0728, ..., 1.0000, 1.6443, 1.5273],
[1.4088, 1.1999, 1.4190, ..., 1.0488, 1.0000, 1.3267],
[1.2788, 1.0000, 1.0000, ..., 1.1839, 1.0000, 1.0000]],
grad_fn=<AddBackward0>),
tensor([[0.4871, 0.7313, 0.0000, ..., 0.0608, 0.2363, 0.0000],
[0.0000, 0.8310, 0.1911, ..., 0.0000, 0.0000, 0.1265],
[0.0254, 0.0000, 0.0000, ..., 0.2306, 0.0000, 0.0835],
...,
[0.0000, 0.0000, 0.0000, ..., 0.5119, 0.0000, 0.7939],
[0.0000, 0.1304, 0.0000, ..., 0.1387, 0.0000, 0.1798],
[0.0000, 0.2281, 0.0400, ..., 0.0000, 0.0000, 0.0000]],
grad_fn=<ReluBackward0>))
Some additional things to note:

Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs x1 and x2, they now have symbolic shapes of (s24, 64) and (s24, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as sample inputs. s24 is a symbol representing that this dimension can range over a set of values.

exported_program.range_constraints describes the range of each symbol appearing in the graph. In this case, we see that s24 has the range [0, int_oo]. For technical reasons that are difficult to explain here, these symbols are assumed not to take the values 0 or 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.

In the example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism; we cover both in the sections below.
Named Dims#
For every dimension specified with Dim("name"), we will allocate a symbolic shape. Specifying a Dim with the same name will result in the same symbol being generated. This allows users to specify which symbol is allocated for each input dimension.
batch = Dim("batch")
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
For each Dim, we can specify minimum and maximum values. We also allow relations between Dims to be specified as univariate linear expressions: A * dim + B. This allows users to specify more complex constraints, such as integer divisibility, on the dynamic behavior of their exported program (ExportedProgram).
dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}
However, ConstraintViolationErrors will be raised if, during tracing, we emit guards that conflict with the given relations or static/dynamic specifications. For example, the specification above asserts the following:

x.shape[0] has range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].
x.shape[1] is static.
y.shape[1] has range [0, 512], and is unrelated to any other dimension.

If any of these assertions are found to be incorrect during tracing (e.g. x.shape[0] is actually static, or y.shape[1] has a smaller range, or y.shape[0] != 2 * x.shape[0]), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification.
Dim Hints#
Instead of explicitly specifying dynamism with Dim("name"), we can let torch.export infer the ranges and relationships of dynamic values using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don't know exactly how dynamic your dynamic values are.
dynamic_shapes = {
    "x": (Dim.DYNAMIC, None),
    "y": (Dim.DYNAMIC, Dim.DYNAMIC),
}
We can also specify min/max values for Dim.DYNAMIC, which will serve as hints to export. But if during tracing export finds the range to be different, it will automatically update the range without raising an error. We also cannot specify relations between dynamic values; instead, these are inferred by export and exposed to users through inspecting the assertions in the graph. With this method of specifying dynamism, ConstraintViolationErrors are only raised if a value specified as dynamic is inferred to be static.

An even more convenient way to specify dynamism is Dim.AUTO, which behaves like Dim.DYNAMIC but does not raise an error if a dimension is inferred to be static. This is useful when you have no idea what the dynamic values are and want to export the program with a "best effort" dynamic approach.
ShapesCollection#
When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
along with the dynamic shapes, we would need to specify the dynamism of tensor_x, tensor_y, and tensor_z:

# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
However, this is rather cumbersome, because the dynamic_shapes specification must mirror the nested structure of the input arguments. Instead, an easier way to specify dynamic shapes is with the helper utility torch.export.ShapesCollection, where instead of specifying the dynamism of every single input, we can just assign directly which input dimensions are dynamic.
import torch

class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z

tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}

print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
            # File: /tmp/ipykernel_558/1070110726.py:5 in forward, code: x = inp["x"] * 1
            mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1); inp_x = None

            # File: /tmp/ipykernel_558/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
            mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2); inp_others_0 = None

            # File: /tmp/ipykernel_558/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
            mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3); inp_others_1 = None
            return (mul, mul_1, mul_2)

Graph signature:
    # inputs
    inp_x: USER_INPUT
    inp_others_0: USER_INPUT
    inp_others_1: USER_INPUT

    # outputs
    mul: USER_OUTPUT
    mul_1: USER_OUTPUT
    mul_2: USER_OUTPUT

Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}
AdditionalInputs#
If you are not sure about the dynamism of your inputs, but have an ample set of testing or profiling data that gives a fair sense of representative inputs for your model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You can specify all the possible inputs used to trace the program, and AdditionalInputs will infer which inputs are dynamic based on which input shapes change.

Example:
import dataclasses
import torch
import torch.utils._pytree as pytree

@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor

pytree.register_dataclass(D)

class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t

input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)

ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)

print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, DimHint(DYNAMIC), None, (DimHint(DYNAMIC),)]}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, d_b, d_i: "Sym(s37)", d_f, d_t: "f32[s99]"):
            # File: /tmp/ipykernel_558/829931439.py:16 in forward, code: return d.i + d.f + d.t
            sym_float: "Sym(ToFloat(s37))" = torch.sym_float(d_i); d_i = None
            add: "Sym(ToFloat(s37) + 3.0)" = sym_float + 3.0; sym_float = None
            add_1: "f32[s99]" = torch.ops.aten.add.Tensor(d_t, add); d_t = add = None
            return (add_1,)

Graph signature:
    # inputs
    d_b: USER_INPUT
    d_i: USER_INPUT
    d_f: USER_INPUT
    d_t: USER_INPUT

    # outputs
    add_1: USER_OUTPUT

Range constraints: {s37: VR[0, int_oo], s99: VR[2, int_oo]}
Serialization#
To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zipfile with a specific structure, whose details are defined in the PT2 Archive Spec.

An example:
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
Export IR: Training vs. Inference#
The graph produced by torch.export contains only ATen operators, which are the basic units of computation in PyTorch. Export provides different IR levels depending on your use case:
| IR type | How to obtain | Properties | Operator count | Use case |
|---|---|---|---|---|
| Training IR | export() (the default) | May contain mutations | ~3000 | Training with autograd |
| Inference IR | run_decompositions(decomp_table={}) | Purely functional | ~2000 | Inference deployment |
| Core ATen IR | run_decompositions(decomp_table=None) | Purely functional, heavily decomposed | ~180 | Minimal backend support |
Training IR (default)#
By default, export produces a Training IR containing all ATen operators, both functional and non-functional (mutating). A functional operator is one that does not contain any mutations or aliasing of its inputs, while a non-functional operator may modify its inputs in place. You can find a list of all ATen operators here, and you can check whether an operator is functional by inspecting op._schema.is_mutable.
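For example, a quick check using the op._schema.is_mutable attribute mentioned above: the in-place aten.add_ is flagged as mutable, while its functional counterpart aten.add is not.

```python
import torch

# The in-place variant mutates its first argument...
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

# ...while the functional variant does not.
print(torch.ops.aten.add.Tensor._schema.is_mutable)
```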
This training IR, which may contain mutations, is designed for training use cases and can be used with eager PyTorch Autograd.
import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:553 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:174 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:194 in forward, code: return F.batch_norm(
        batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        return (batch_norm,)
Inference IR (via run_decompositions)#
To get an Inference IR suitable for deployment, use the ExportedProgram.run_decompositions() API. This method will:

Functionalize the graph (remove all mutations and replace them with their functional equivalents)

Optionally decompose ATen operators based on the decomposition table that is passed in

This produces a purely functional graph, which is ideal for inference scenarios.

By specifying an empty decomposition table (decomp_table={}), you get pure functionalization without any additional decomposition. This results in an Inference IR containing ~2000 functional operators (compared to the 3000+ operators in the Training IR).
import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})

print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:553 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:174 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:194 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
As we can see, the previously in-place operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator.
Core ATen IR#
We can further lower the Inference IR down to the [Core ATen Operator Set](https://pytorch.ac.cn/docs/stable/torch.compiler_ir.html#core-aten-ir), which contains only ~180 operators. This is done by passing decomp_table=None (which uses the default decomposition table) to run_decompositions(). This IR is optimal for backends who want to minimize the number of operators they need to implement.
import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)

print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:553 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:174 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:194 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator: operations like conv1d and conv2d can all be implemented with the same op.

We can also specify our own decomposition behavior:
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:553 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
        mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:174 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None

        # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:194 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default alone, it is now decomposed into torch.ops.aten.convolution.default followed by torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.
Limitations of torch.export#
As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it may ultimately run into untraceable parts of a program, as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation causes a "graph break" and the unsupported operation is run with default Python evaluation. In contrast, torch.export requires users to provide additional information or rewrite parts of their code to make it traceable.

Draft-export is a great resource for listing out the graph breaks that will be encountered when tracing a program, along with additional debug information to solve those errors.

ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.
TorchDynamo unsupported#
When using torch.export with strict=True, TorchDynamo evaluates the program at the Python bytecode level to trace it into a graph. Compared to previous tracing frameworks, there are significantly fewer rewrites required to make a program traceable, but some Python features remain unsupported. One option to get past such graph breaks is non-strict export, enabled by changing the strict flag to strict=False.
Data/Shape-Dependent Control Flow#
Graph breaks can also be encountered on data-dependent control flow (e.g. if x.shape[0] > 2) when shapes are not being specialized, as a tracing compiler cannot possibly deal with this without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else-like control flow (more coming soon!).
You can also refer to this tutorial for more ways of addressing data-dependent errors.
Missing Fake/Meta Kernels for Operators#
When tracing, a FakeTensor kernel (aka meta kernel) is required for all operators. This is used to reason about the input/output shapes of the operator.

Please see this tutorial for more details.

In the unfortunate case where your model uses an ATen operator that does not yet have a FakeTensor kernel implementation, please file an issue.
Read More#
Additional Links for Export Users
Deep Dive for PyTorch Developers