torch.export#
Created On: Jun 12, 2025 | Last Updated On: Aug 11, 2025
Overview#
torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.
import torch
from torch.export import export, ExportedProgram
class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# File: /tmp/ipykernel_647/2550508656.py:6 in forward, code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x); x = None
# File: /tmp/ipykernel_647/2550508656.py:7 in forward, code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y); y = None
# File: /tmp/ipykernel_647/2550508656.py:8 in forward, code: return a + b
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {}
torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.
Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.
Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.
Graph properties: The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.
Metadata: The graph contains metadata captured during tracing, such as a stack trace from user code.
Under the hood, torch.export leverages the following latest technologies:
TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph-capturing experience, with far fewer rewrites needed in order to fully trace PyTorch code.
AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.
Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.
Existing frameworks#
torch.compile() also utilizes the same PT2 stack as torch.export, but with some differences:
JIT vs. AOT: torch.compile() is a JIT compiler and is not intended to be used to produce compiled artifacts outside of deployment, whereas torch.export is an AOT compiler.
Partial vs. full graph capture: When torch.compile() runs into an untraceable part of a model, it will "graph break" and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can be saved, loaded, and run in different environments and languages.
Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is a lot more flexible. torch.export instead requires users to provide more information or rewrite their code to make it traceable.
Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of Tensor metadata, so that conditionals on things like Tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.
Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, unless explicit control flow operators are used, but it supports more Python language features due to its comprehensive coverage of Python bytecode. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.
Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computations on Tensor sizes, and it records all of the side conditions necessary to ensure that a particular trace is valid for other inputs.
Exporting a PyTorch Model#
The main entrypoint is torch.export.export(), which takes a torch.nn.Module and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:
import torch
from torch.export import export, ExportedProgram
# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); x = p_conv_weight = p_conv_bias = None
# File: /tmp/ipykernel_647/2848084713.py:16 in forward, code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant); conv2d = constant = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_); add_ = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]); relu = None
return (max_pool2d,)
Graph signature:
# inputs
p_conv_weight: PARAMETER target='conv.weight'
p_conv_bias: PARAMETER target='conv.bias'
x: USER_INPUT
constant: USER_INPUT
# outputs
max_pool2d: USER_OUTPUT
Range constraints: {}
tensor([[[[1.3647, 2.2724, 1.9313, ..., 2.2872, 2.0532, 1.7862],
[2.3376, 1.7866, 2.0821, ..., 1.8447, 2.5090, 1.4548],
[2.1719, 2.0520, 1.6153, ..., 2.1426, 1.4364, 1.7749],
...,
[2.1165, 2.0080, 1.6747, ..., 1.8013, 2.3714, 2.1069],
[1.9574, 1.9642, 1.6017, ..., 2.2444, 1.7903, 2.1224],
[1.6735, 1.4687, 1.4179, ..., 2.1677, 1.4616, 1.9001]],
[[1.7637, 2.2845, 1.6327, ..., 1.8567, 1.6745, 1.3376],
[1.8699, 1.1547, 2.5045, ..., 1.1387, 2.0210, 1.3825],
[1.7050, 1.3393, 1.9955, ..., 1.4296, 1.8792, 1.7073],
...,
[1.5292, 1.9183, 1.8844, ..., 2.0815, 1.8089, 2.5264],
[2.0338, 2.3296, 2.1650, ..., 2.0727, 2.0166, 1.7309],
[1.8381, 1.6556, 1.9402, ..., 1.6529, 1.8134, 1.6075]],
[[1.6673, 2.1074, 1.5976, ..., 1.7421, 1.7998, 1.6087],
[2.0269, 1.3379, 1.6679, ..., 1.6671, 1.6000, 1.9894],
[2.0480, 1.5340, 1.4017, ..., 1.7944, 0.9860, 2.3785],
...,
[1.5266, 1.3949, 1.2980, ..., 1.3569, 1.9492, 1.8062],
[1.8315, 1.2293, 1.1087, ..., 1.5446, 1.6492, 1.6620],
[1.4799, 1.3720, 1.5748, ..., 1.8854, 1.2940, 1.7422]],
...,
[[1.5734, 2.2576, 1.6242, ..., 2.2690, 1.5416, 1.6914],
[1.6577, 1.1605, 1.3565, ..., 0.8677, 1.1838, 1.7662],
[2.0769, 1.6546, 1.6169, ..., 2.1301, 1.3892, 1.7564],
...,
[1.4403, 2.0147, 2.0693, ..., 1.8359, 1.3394, 1.6654],
[1.6513, 2.2535, 2.0069, ..., 1.0825, 1.6039, 1.3635],
[1.3937, 1.0050, 2.8575, ..., 1.9308, 1.4201, 1.4665]],
[[1.8365, 1.3995, 1.8011, ..., 1.7474, 2.0621, 1.7876],
[1.7984, 2.0523, 1.4683, ..., 1.7031, 2.4383, 1.2690],
[1.6586, 1.4678, 1.7569, ..., 1.4851, 1.5530, 2.1754],
...,
[1.3299, 1.2796, 1.9345, ..., 1.8214, 2.7972, 2.0472],
[1.3732, 1.3926, 1.5657, ..., 1.4157, 2.2771, 1.9893],
[1.4669, 1.7627, 1.6258, ..., 1.5350, 2.0609, 2.1192]],
[[1.8094, 2.4813, 1.9160, ..., 2.0701, 1.2966, 1.0297],
[2.2580, 1.9149, 1.8033, ..., 1.4026, 1.9687, 1.5859],
[1.3034, 1.1883, 1.5391, ..., 1.5249, 1.5463, 2.0138],
...,
[2.2308, 1.7266, 1.8078, ..., 1.7636, 2.2842, 1.5676],
[1.7690, 1.3066, 1.8903, ..., 1.8704, 1.5074, 1.5758],
[2.3337, 1.5704, 1.8364, ..., 1.4995, 1.8263, 1.8836]]]],
grad_fn=<MaxPool2DWithIndicesBackward0>)
Inspecting the ExportedProgram, we can note the following:
The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.
The graph contains only torch.ops.aten operators (found here) and custom operators.
The parameters (the weight and bias of conv) are lifted as inputs to the graph, so there are no `get_attr` nodes in the graph, which would exist in the result of torch.fx.symbolic_trace().
The torch.export.ExportGraphSignature models the input and output signature, and specifies which inputs are parameters.
The shape and dtype of the tensor produced by each node in the graph are recorded. For example, the `conv2d` node produces a tensor of dtype `torch.float32` and shape (1, 16, 256, 256).
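As a rough illustrative sketch (not part of the original example), this recorded information can also be inspected programmatically: `exported_program.graph_signature` holds the signature shown above, and each node's `meta["val"]` holds a fake tensor carrying the recorded shape and dtype.
# Assuming `exported_program` from the example above.
print(exported_program.graph_signature)
for node in exported_program.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")  # FakeTensor recording shape/dtype
        print(node.name, node.target, getattr(val, "shape", None), getattr(val, "dtype", None))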
Expressing Dynamism#
By default, torch.export traces the program assuming all input shapes are static, and specializes the exported program to those dimensions. One consequence of this is that at runtime, the program will not work on inputs with different shapes, even if they are valid in eager mode.
An example:
import torch
import traceback as tb
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
ep = torch.export.export(M(), example_args)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_647/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {}
Traceback (most recent call last):
File "/tmp/ipykernel_647/1522925308.py", line 28, in <module>
ep.module()(*example_args2) # fails
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 413, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
File "<eval_with_key>.25", line 11, in forward
_guards_fn = self._guards_fn(x1, x2); _guards_fn = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "<string>", line 3, in _
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/__init__.py", line 2185, in _assert
assert condition, message
AssertionError: Guard failed: x1.size()[0] == 32
However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be created using the torch.export.Dim() API and passed to torch.export.export() through the dynamic_shapes argument.
import torch
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2) # success
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_647/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.8873, 1.2537, ..., 1.8441, 1.5857, 1.2611],
[1.0000, 1.0000, 1.0598, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 2.3600, ..., 1.8214, 1.0710, 1.0000],
...,
[1.0000, 1.0516, 1.1907, ..., 1.4364, 1.0000, 1.0000],
[1.0000, 1.1808, 1.0000, ..., 1.0000, 1.0000, 1.3670],
[1.5627, 2.3273, 1.0000, ..., 1.0000, 1.0000, 1.6566]],
grad_fn=<AddBackward0>),
tensor([[0.0000, 0.0000, 0.0000, ..., 0.1734, 0.0999, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.4386, 0.7500],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[1.1618, 0.0000, 0.0000, ..., 0.0000, 0.1554, 0.0000],
[0.0000, 0.4619, 0.0904, ..., 0.0000, 0.0792, 0.5539],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
grad_fn=<ReluBackward0>))
Some additional things to note:
Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and `x2`, they have symbolic shapes of (s24, 64) and (s24, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. `s24` is a symbol representing that this dimension can range over a set of values.
exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that `s24` has the range [0, int_oo]. For technical reasons that are difficult to explain here, such symbols are assumed to be neither 0 nor 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.
In this example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism; we will cover both in a following section.
Named Dims#
For every dimension specified with Dim("name"), we allocate a symbolic shape. Specifying a Dim with the same name results in the same symbol being generated. This allows users to control which symbols are allocated to each input's dimensions.
batch = Dim("batch")
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
For each Dim, we can specify minimum and maximum values. We also allow specifying relations between Dims as univariate linear expressions: A * dim + B. This lets users specify more complex constraints, such as integer divisibility, for dynamic dimensions. These features allow users to place explicit restrictions on the dynamic behavior of the resulting ExportedProgram.
dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
"x": (dx, None),
"y": (2 * dx, dh),
}
However, ConstraintViolationErrors will be raised if guards emitted while tracing conflict with the given relations or static/dynamic specifications. For example, the specification above asserts the following:
x.shape[0] has range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].
x.shape[1] is static.
y.shape[1] has range [0, 512], and is unrelated to any other dimension.
If any of these assertions are found to be incorrect while tracing (for example, `x.shape[0]` is actually static, or `y.shape[1]` has a smaller range, or `y.shape[0] != 2 * x.shape[0]`), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification.
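As an illustrative sketch of what such a conflict can look like (hypothetical module; the exact error type and message depend on the PyTorch version), marking a dimension as dynamic while the traced code branches on its concrete value forces a specialization that conflicts with the specification:
import torch

class Specialize(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] == 32:  # guard that specializes dim 0 to 32
            return x + 1
        return x - 1

batch = torch.export.Dim("batch")
try:
    torch.export.export(
        Specialize(), (torch.randn(32, 64),), dynamic_shapes={"x": {0: batch}}
    )
except Exception as e:  # expected to surface a ConstraintViolationError
    print(type(e).__name__)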
Dim Hints#
Instead of explicitly specifying dynamism with Dim("name"), we can let torch.export infer the ranges and relationships of dynamic values by using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don't know exactly how dynamic your dynamic values are.
dynamic_shapes = {
"x": (Dim.DYNAMIC, None),
"y": (Dim.DYNAMIC, Dim.DYNAMIC),
}
We can also specify min/max values for Dim.DYNAMIC, which serve as hints to export. If export finds a different range while tracing, it automatically updates the range without raising an error. We also cannot specify relationships between dynamic values; these are instead inferred by export and exposed to users through assertions that can be inspected in the graph. With this method of specifying dynamism, a ConstraintViolationError is raised only if a value specified as dynamic is inferred to be static.
An even more convenient way to specify dynamism is Dim.AUTO, which behaves like Dim.DYNAMIC but does not raise an error if the dimension is inferred to be static. This is useful when you know nothing about the ranges of the dynamic values and want to export the program with a "best effort" amount of dynamism.
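A minimal sketch of using Dim.AUTO (the module and shapes here are made up for illustration): each dimension marked with Dim.AUTO becomes a symbol if the traced code allows it, and silently stays static otherwise, without raising an error.
import torch
from torch.export import Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=1)

ep = torch.export.export(
    M(), (torch.randn(8, 16),), dynamic_shapes={"x": (Dim.AUTO, Dim.AUTO)}
)
print(ep)  # both dimensions of x appear as symbols in this case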
ShapesCollection#
When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
we would need to specify the dynamism of `tensor_x`, `tensor_y`, and `tensor_z` with a dynamic_shapes specification such as:
# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}
torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
However, this is rather cumbersome, because the dynamic_shapes specification must mirror the nested structure of the input arguments. Instead, an easier way to specify dynamic shapes is with the helper utility torch.export.ShapesCollection, where instead of specifying the dynamism of every single input, we can directly assign which input dimensions are dynamic.
import torch
class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z
tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}
print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
# File: /tmp/ipykernel_647/1070110726.py:5 in forward, code: x = inp["x"] * 1
mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1); inp_x = None
# File: /tmp/ipykernel_647/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2); inp_others_0 = None
# File: /tmp/ipykernel_647/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3); inp_others_1 = None
return (mul, mul_1, mul_2)
Graph signature:
# inputs
inp_x: USER_INPUT
inp_others_0: USER_INPUT
inp_others_1: USER_INPUT
# outputs
mul: USER_OUTPUT
mul_1: USER_OUTPUT
mul_2: USER_OUTPUT
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}
AdditionalInputs#
If you don't know how dynamic your inputs are, but you have an ample set of testing or profiling data that gives a fair sense of representative inputs for the model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You can specify all of the possible inputs used to trace the program, and AdditionalInputs will infer which input dimensions are dynamic based on which input shapes change.
Example:
import dataclasses
import torch
import torch.utils._pytree as pytree
@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor
pytree.register_dataclass(D)
class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t
input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)
print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, d_b, d_i: "Sym(s37)", d_f, d_t: "f32[s99]"):
# File: /tmp/ipykernel_647/829931439.py:16 in forward, code: return d.i + d.f + d.t
sym_float: "Sym(ToFloat(s37))" = torch.sym_float(d_i); d_i = None
add: "Sym(ToFloat(s37) + 3.0)" = sym_float + 3.0; sym_float = None
add_1: "f32[s99]" = torch.ops.aten.add.Tensor(d_t, add); d_t = add = None
return (add_1,)
Graph signature:
# inputs
d_b: USER_INPUT
d_i: USER_INPUT
d_f: USER_INPUT
d_t: USER_INPUT
# outputs
add_1: USER_OUTPUT
Range constraints: {s37: VR[0, int_oo], s99: VR[2, int_oo]}
Serialization#
To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zip file with a specific structure; the details of the structure are defined in the PT2 Archive Spec.
An example:
import torch
class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10
exported_program = torch.export.export(MyModule(), (torch.randn(5),))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
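The loaded program behaves like the original exported program; as a quick sanity check (assuming the file above was just written), we can compare their outputs:
inp = (torch.randn(5),)
assert torch.allclose(
    saved_exported_program.module()(*inp), exported_program.module()(*inp)
)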
Export IR, Decompositions#
The graph produced by torch.export contains only ATen operators, which are the basic units of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.
By default, export produces the most generic IR, which contains all ATen operators, both functional and non-functional. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator is functional by checking `op._schema.is_mutable`.
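For example, a quick illustrative check of this flag on the in-place and out-of-place add overloads (expected results noted in comments, not copied from a run):
import torch

print(torch.ops.aten.add_.Tensor._schema.is_mutable)  # expected truthy: add_ mutates its first input
print(torch.ops.aten.add.Tensor._schema.is_mutable)   # expected falsy: add is functional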
This generic IR can be used to train in eager PyTorch Autograd.
import torch
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
return (batch_norm,)
However, if you want to use the IR for inference, or reduce the number of operators being used, you can lower the graph through the `ExportedProgram.run_decompositions()` API. This method decomposes the ATen operators into the ones specified in a decomposition table, and functionalizes the graph.
By specifying an empty decomposition table, we only perform functionalization without any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above) and is ideal for inference cases.
import torch
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
We can see that the previously in-place operator `torch.ops.aten.add_.Tensor` has now been replaced with `torch.ops.aten.add.Tensor`, a functional operator.
We can further lower this exported program down to an operator set which only contains the Core ATen Operator Set, a collection of ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.
import torch
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
We now see that `torch.ops.aten.conv2d.default` has been decomposed into `torch.ops.aten.convolution.default`. This is because `convolution` is a more "core" operator, as operations like `conv1d` and `conv2d` can be implemented using the same op.
We can also specify our own decomposition behavior:
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
my_decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
Notice that instead of `torch.ops.aten.conv2d.default` being decomposed into `torch.ops.aten.convolution.default`, it is now decomposed into `torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, which matches our custom decomposition rule.
Limitations of torch.export#
As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it may ultimately run into untraceable parts of the program, since it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation causes a "graph break", and the unsupported operation is run with default Python evaluation. In contrast, torch.export requires users to provide additional information or rewrite parts of their code to make it traceable.
Draft-export is a great resource for listing out the graph breaks that will be encountered when tracing the program, along with additional debug information for solving those errors.
ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.
TorchDynamo Unsupported#
When using torch.export with `strict=True`, it uses TorchDynamo to evaluate the program at the Python bytecode level and trace it into a graph. Compared to previous tracing frameworks, far fewer rewrites are needed to make a program traceable, but some Python features remain unsupported. One way to get past such graph breaks is to use non-strict export by changing the `strict` flag to `strict=False`.
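A minimal sketch of opting into non-strict export (hypothetical module; only the `strict=False` flag is the point here):
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Non-strict export traces with the Python interpreter instead of TorchDynamo,
# which sidesteps some TorchDynamo-unsupported constructs.
ep = torch.export.export(M(), (torch.randn(4),), strict=False)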
Data/Shape-Dependent Control Flow#
Graph breaks can also be encountered on data-dependent control flow (e.g. `if x.shape[0] > 2`) when shapes are not specialized, as a tracing compiler cannot handle this without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else like control flow (more coming soon!).
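A rough sketch of such a rewrite with torch.cond (assuming its usual calling convention of a predicate, two branch callables, and a tuple of operands):
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # Instead of `if x.shape[0] > 2: ... else: ...`, provide both branches explicitly.
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        return torch.cond(x.shape[0] > 2, true_fn, false_fn, (x,))

batch = torch.export.Dim("batch")
ep = torch.export.export(
    M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}
)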
You can also refer to this tutorial for more ways of addressing data-dependent errors.