torch.onnx#
创建于: 2025 年 6 月 10 日 | 最后更新于: 2025 年 6 月 10 日
概述#
Open Neural Network eXchange (ONNX) 是一种用于表示机器学习模型的开放标准格式。torch.onnx
模块从原生 PyTorch torch.nn.Module
模型捕获计算图,并将其转换为 ONNX 图。
导出的模型可以被许多支持 ONNX 的运行时使用,包括 Microsoft 的 ONNX Runtime。
有两种 ONNX 导出器 API 可供使用,如下所列。两者都可以通过函数 torch.onnx.export()
调用。下一个示例展示了如何导出简单模型。
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 128, 5)
def forward(self, x):
return torch.relu(self.conv1(x))
input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
model, # model to export
(input_tensor,), # inputs of the model,
"my_model.onnx", # filename of the ONNX model
input_names=["input"], # Rename inputs for the ONNX model
dynamo=True # True or False to select the exporter to use
)
下一节将介绍导出器的两个版本。
基于 TorchDynamo 的 ONNX 导出器#
基于 TorchDynamo 的 ONNX 导出器是 PyTorch 2.1 及更高版本中最新的(测试版)导出器
TorchDynamo 引擎用于连接 Python 的帧评估 API,并将其字节码动态重写为 FX 图。然后,生成的 FX 图经过完善,最终转换为 ONNX 图。
这种方法的主要优点是 FX 图是通过字节码分析捕获的,它保留了模型的动态特性,而不是使用传统的静态追踪技术。
基于 TorchScript 的 ONNX 导出器#
基于 TorchScript 的 ONNX 导出器自 PyTorch 1.2.0 起可用
通过 TorchScript(通过 torch.jit.trace()
)追踪模型并捕获静态计算图。
因此,生成的图有一些限制
它不记录任何控制流,如 if 语句或循环;
不处理
training
和eval
模式之间的细微差别;无法真正处理动态输入
为了支持静态追踪的限制,导出器还支持 TorchScript 脚本(通过 torch.jit.script()
),这增加了对数据相关控制流的支持。然而,TorchScript 本身是 Python 语言的一个子集,因此并非所有 Python 功能都受支持,例如就地操作。