基于torch.export的ONNX导出器#
创建日期:2025年6月10日 | 最后更新日期:2025年8月22日
概述#
利用torch.export引擎,以提前编译(AOT)的方式生成一个仅包含函数张量计算的图。生成的图(1)在函数式ATen算子集中产生规范化的算子(以及任何用户指定的自定义算子),(2)消除了所有Python控制流和数据结构(某些例外情况),并且(3)记录了一组形状约束,以证明这种规范化和控制流消除对于未来的输入是合理的,然后最终将其转换为ONNX图。
此外,在导出过程中,内存使用量也显著降低。
依赖项#
ONNX导出器依赖于额外的Python包。
可以通过pip安装它们。
pip install --upgrade onnx onnxscript
然后可以使用onnxruntime在多种处理器上执行模型。
一个简单的示例#
下面是一个使用简单多层感知器(MLP)作为示例的导出器API演示。
import torch
import torch.nn as nn
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(8, 8, bias=True)
self.fc1 = nn.Linear(8, 4, bias=True)
self.fc2 = nn.Linear(4, 2, bias=True)
self.fc3 = nn.Linear(2, 2, bias=True)
self.fc_combined = nn.Linear(8 + 8 + 8, 8, bias=True) # Combine all inputs
def forward(self, tensor_x: torch.Tensor, input_dict: dict, input_list: list):
"""
Forward method that requires all inputs:
- tensor_x: A direct tensor input.
- input_dict: A dictionary containing the tensor under the key 'tensor_x'.
- input_list: A list where the first element is the tensor.
"""
# Extract tensors from inputs
dict_tensor = input_dict['tensor_x']
list_tensor = input_list[0]
# Combine all inputs into a single tensor
combined_tensor = torch.cat([tensor_x, dict_tensor, list_tensor], dim=1)
# Process the combined tensor through the layers
combined_tensor = self.fc_combined(combined_tensor)
combined_tensor = torch.sigmoid(combined_tensor)
combined_tensor = self.fc0(combined_tensor)
combined_tensor = torch.sigmoid(combined_tensor)
combined_tensor = self.fc1(combined_tensor)
combined_tensor = torch.sigmoid(combined_tensor)
combined_tensor = self.fc2(combined_tensor)
combined_tensor = torch.sigmoid(combined_tensor)
output = self.fc3(combined_tensor)
return output
model = MLPModel()
# Example inputs
tensor_input = torch.rand((97, 8), dtype=torch.float32)
dict_input = {'tensor_x': torch.rand((97, 8), dtype=torch.float32)}
list_input = [torch.rand((97, 8), dtype=torch.float32)]
# The input_names and output_names are used to identify the inputs and outputs of the ONNX model
input_names = ['tensor_input', 'tensor_x', 'list_input_index_0']
output_names = ['output']
# Exporting the model with all required inputs
onnx_program = torch.onnx.export(model,(tensor_input, dict_input, list_input), dynamic_shapes=({0: "batch_size"},{"tensor_x": {0: "batch_size"}},[{0: "batch_size"}]), input_names=input_names, output_names=output_names, dynamo=True,)
# Check the exported ONNX model is dynamic
assert onnx_program.model.graph.inputs[0].shape == ("batch_size", 8)
assert onnx_program.model.graph.inputs[1].shape == ("batch_size", 8)
assert onnx_program.model.graph.inputs[2].shape == ("batch_size", 8)
如上面的代码所示,您只需向torch.onnx.export()
提供模型实例及其输入。导出器随后将返回一个torch.onnx.ONNXProgram
实例,其中包含导出的ONNX图以及额外的信息。
通过onnx_program.model_proto
可用的内存中模型是一个符合ONNX IR规范的onnx.ModelProto
对象。然后可以使用torch.onnx.ONNXProgram.save()
API将ONNX模型序列化为Protobuf文件。
onnx_program.save("mlp.onnx")
转换失败时#
应第二次调用torch.onnx.export()
函数,并将参数report=True
。将生成一个markdown报告,以帮助用户解决问题。
元数据#
在ONNX导出过程中,每个ONNX节点都带有元数据注解,这些注解有助于追溯其在原始PyTorch模型中的来源和上下文。这些元数据对于调试、模型检查和理解PyTorch与ONNX图之间的映射非常有用。
每个ONNX节点都会添加以下元数据字段:
namespace
一个表示节点分层命名空间的字符串,由模块/方法的堆栈跟踪组成。
示例:
__main__.SimpleAddModel/add: aten.add.Tensor
pkg.torch.onnx.class_hierarchy
一个表示通往该节点的模块层级的类名列表。
示例:
['__main__.SimpleAddModel', 'aten.add.Tensor']
pkg.torch.onnx.fx_node
原始FX节点的字符串表示,包括其名称、使用者数量、目标torch op、参数和关键字参数。
示例:
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%tensor_x, %input_dict_tensor_x, %input_list_0], 1), kwargs = {})
pkg.torch.onnx.name_scopes
一个表示PyTorch模型中节点路径的名称作用域(方法)列表。
示例:
['', 'add']
pkg.torch.onnx.stack_trace
如果可用,这是创建该节点时原始代码的堆栈跟踪。
示例
File "simpleadd.py", line 7, in forward return torch.add(x, y)
这些元数据字段存储在每个ONNX节点的metadata_props属性中,可以使用Netron或通过编程方式进行检查。
整个ONNX图具有以下metadata_props
:
pkg.torch.export.ExportedProgram.graph_signature
此属性包含原始PyTorch ExportedProgram的graph_signature的字符串表示。图签名描述了模型输入和输出的结构以及它们如何映射到ONNX图。输入被定义为
InputSpec
对象,这些对象包括输入的种类(例如,参数的InputKind.PARAMETER
,用户定义的输入的InputKind.USER_INPUT
)、参数名称、目标(可以是模型中的特定节点)以及输入是否是持久的。输出被定义为OutputSpec
对象,这些对象指定输出的种类(例如,OutputKind.USER_OUTPUT
)和参数名称。要了解有关图签名的更多信息,请参阅torch.export。
pkg.torch.export.ExportedProgram.range_constraints
此属性包含原始PyTorch ExportedProgram中存在的任何范围约束的字符串表示。范围约束指定了模型中符号形状或值的有效范围,这对于使用动态形状或符号维度的模型可能很重要。
示例:
s0: VR[2, int_oo]
,这表示输入张量的尺寸必须至少为2。要了解有关范围约束的更多信息,请参阅torch.export。
ONNX图中的每个输入值可能具有以下元数据属性:
pkg.torch.export.graph_signature.InputSpec.kind
输入的种类,由PyTorch的InputKind枚举定义。
示例值:
“USER_INPUT”: 用户提供的模型输入。
“PARAMETER”: 模型参数(例如,权重)。
“BUFFER”: 模型缓冲区(例如,BatchNorm中的运行均值)。
“CONSTANT_TENSOR”: 常量张量参数。
“CUSTOM_OBJ”: 自定义对象输入。
“TOKEN”: token输入。
pkg.torch.export.graph_signature.InputSpec.persistent
指示输入是否持久(即,是否应作为模型状态的一部分保存)。
示例值:
“True”
“False”
ONNX图中的每个输出值可能具有以下元数据属性:
pkg.torch.export.graph_signature.OutputSpec.kind
输入的种类,由PyTorch的OutputKind枚举定义。
示例值:
“USER_OUTPUT”: 用户可见的输出。
“LOSS_OUTPUT”: 损失值输出。
“BUFFER_MUTATION”: 表示缓冲区已发生变异。
“GRADIENT_TO_PARAMETER”: 参数的梯度输出。
“GRADIENT_TO_USER_INPUT”: 用户输入的梯度输出。
“USER_INPUT_MUTATION”: 表示用户输入已发生变异。
“TOKEN”: token输出。
每个已初始化的值、输入、输出都具有以下元数据:
pkg.torch.onnx.original_node_name
在PyTorch FX图中生成此值的节点的原始名称,以防该值被重命名。这有助于将初始化器追溯到其在原始模型中的来源。
示例:
fc1.weight
API参考#
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, verbose=None, input_names=None, output_names=None, opset_version=None, dynamo=True, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, export_params=True, keep_initializers_as_inputs=False, dynamic_axes=None, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[source]#
将模型导出为ONNX格式。
设置
dynamo=True
会启用新的ONNX导出逻辑,该逻辑基于torch.export.ExportedProgram
和更现代的翻译逻辑集。这是将模型导出为ONNX的推荐且默认的方式。当
dynamo=True
时:导出器尝试以下策略来获取用于转换为ONNX的ExportedProgram。
如果模型已经是ExportedProgram,则将按原样使用。
使用
torch.export.export()
并设置strict=False
。使用
torch.export.export()
并设置strict=True
。
- 参数
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 要导出的模型。
args (tuple[Any, ...]) – 示例位置输入。任何非张量参数都将被硬编码到导出的模型中;任何张量参数都将成为导出模型的输入,其顺序与元组中的顺序相同。
f (str | os.PathLike | None) – 输出ONNX模型文件的路径。例如:“model.onnx”。此参数保留用于向后兼容。建议将其留空(None),而是使用返回的
torch.onnx.ONNXProgram
来序列化模型到文件。verbose (bool | None) – 是否启用详细日志记录。
input_names (Sequence[str] | None) – 要按顺序分配给图的输入节点的名称。
output_names (Sequence[str] | None) – 要按顺序分配给图的输出节点的名称。
opset_version (int | None) – 要针对的默认(ai.onnx)算子集的版本。您应根据要运行导出模型的运行时后端或编译器的支持的算子集版本设置
opset_version
。留空(默认值None
)以使用推荐版本,或参考ONNX算子文档以获取更多信息。dynamo (bool) – 是否使用
torch.export
ExportedProgram而不是TorchScript导出模型。external_data (bool) – 是否将模型权重保存为外部数据文件。对于权重过大且超过ONNX文件大小限制(2GB)的模型,这是必需的。如果设置为False,权重将与模型架构一起保存在ONNX文件中。
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型输入的动态形状的字典或元组。有关更多详细信息,请参阅
torch.export.export()
。仅当dynamo为True时使用(且首选)。请注意,dynamic_shapes设计用于当模型使用dynamo=True导出时,而dynamic_axes用于dynamo=False时。custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – 用于模型中算子的自定义分解的字典。字典的键应为fx Node中的可调用目标(例如
torch.ops.aten.stft.default
),值应为使用ONNX Script构建该图的函数。此选项仅在dynamo为True时有效。report (bool) – 是否为导出过程生成markdown报告。此选项仅在dynamo为True时有效。
optimize (bool) – 是否优化导出的模型。此选项仅在dynamo为True时有效。默认为True。
verify (bool) – 是否使用ONNX Runtime验证导出的模型。此选项仅在dynamo为True时有效。
profile (bool) – 是否分析导出过程。此选项仅在dynamo为True时有效。
dump_exported_program (bool) – 是否将
torch.export.ExportedProgram
转储到文件。这对于调试导出器很有用。此选项仅在dynamo为True时有效。artifacts_dir (str | os.PathLike) – 用于保存调试工件(如报告和序列化的导出程序)的目录。此选项仅在dynamo为True时有效。
fallback (bool) – 如果dynamo导出器失败,是否回退到TorchScript导出器。此选项仅在dynamo为True时有效。启用回退时,建议即使提供了dynamic_shapes也设置dynamic_axes。
export_params (bool) –
当指定了“f”时:如果为false,则不会导出参数(权重)。
您也可以将其留空,并使用返回的
torch.onnx.ONNXProgram
来控制序列化模型时如何处理初始化器。keep_initializers_as_inputs (bool) –
当指定了“f”时:如果为True,所有初始化器(通常对应于模型权重)都将添加到图的输入中。如果为False,则初始化器不会作为输入添加到图,并且只添加用户输入作为输入。
如果打算在运行时提供模型权重,请将其设置为True。如果权重是静态的,则将其设置为False,以便后端/运行时进行更好的优化(例如,常量折叠)。
您也可以将其留空,并使用返回的
torch.onnx.ONNXProgram
来控制序列化模型时如何处理初始化器。dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
当
dynamo=True
且fallback
未启用时,首选指定dynamic_shapes
。默认情况下,导出的模型的所有输入和输出张量的形状都将精确匹配
args
中给出的形状。要将张量的轴指定为动态(即,仅在运行时已知),请将dynamic_axes
设置为一个具有以下模式的字典:- KEY(str):输入或输出名称。每个名称还必须在
input_names
或 output_names
.
- KEY(str):输入或输出名称。每个名称还必须在
- VALUE(dict或list):如果为dict,则键是轴索引,值是轴名称。如果为
list,则每个元素都是一个轴索引。
例如
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], )
产生
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
而
torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], }, )
产生
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
training (_C_onnx.TrainingMode) – 已弃用选项。而是先设置模型的训练模式,然后再导出。
operator_export_type (_C_onnx.OperatorExportTypes) – 已弃用选项。仅支持ONNX。
do_constant_folding (bool) – 已弃用选项。
export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) – 已弃用选项。
autograd_inlining (bool) – 已弃用选项。
- 返回
torch.onnx.ONNXProgram
如果dynamo为True,否则为None。- 返回类型
ONNXProgram | None
版本2.6已更改: _training_现已弃用。而是先设置模型的训练模式,然后再导出。 _operator_export_type_现已弃用。仅支持ONNX。 _do_constant_folding_现已弃用。它始终启用。 _export_modules_as_functions_现已弃用。 _autograd_inlining_现已弃用。
版本2.7已更改: _optimize_现在默认设置为True。
版本2.9已更改: _dynamo_现在默认设置为True。
- class torch.onnx.ONNXProgram(model, exported_program)#
一个表示可与torch张量调用的ONNX程序的类。
- 变量
model – ONNX模型,作为ONNX IR模型对象。
exported_program – 产生ONNX模型的导出程序。
- apply_weights(state_dict)[source]#
将指定state_dict中的权重应用于ONNX模型。
使用此方法替换FakeTensors或其他权重。
- 参数
state_dict (dict[str, torch.Tensor]) – 包含要应用于ONNX模型的权重的state_dict。
- compute_values(value_names, args=(), kwargs=None)[source]#
计算ONNX模型中指定名称的值。
此方法用于计算ONNX模型中指定名称的值。这些值将作为映射名称到张量的字典返回。
- initialize_inference_session(initializer=<function _ort_session_initializer>)[source]#
初始化ONNX Runtime推理会话。
- property model_proto: ModelProto#
返回ONNX
ModelProto
对象。
- save(destination, *, include_initializers=True, keep_initializers_as_inputs=False, external_data=None)[source]#
将ONNX模型保存到指定的目标。
当
external_data
为True
或模型大于2GB时,权重将作为外部数据保存在单独的文件中。初始化器(模型权重)的序列化行为:
include_initializers=True
,keep_initializers_as_inputs=False
(默认): 初始化器包含在保存的模型中。include_initializers=True
,keep_initializers_as_inputs=True
: 初始化器包含在保存的模型中,并作为模型输入保留。选择此选项是为了在推理时能够通过提供初始化器作为模型输入来覆盖模型权重。include_initializers=False
,keep_initializers_as_inputs=False
: 初始化器不包含在保存的模型中,也不列为模型输入。选择此选项是为了在后续处理步骤中将初始化器附加到ONNX模型。include_initializers=False
,keep_initializers_as_inputs=True
: 初始化器不包含在保存的模型中,但会作为模型输入列出。选择此选项是为了在推理时提供初始化器,并最大限度地减小保存的模型的大小。
- 参数
- 引发
TypeError – 如果
external_data
为True
且destination
不是文件路径。
- class torch.onnx.OnnxExporterError#
ONNX导出器引发的错误。这是所有导出器错误的基类。