TorchScript-based ONNX Exporter#
Created On: Aug 31, 2017 | Last Updated On: Jun 10, 2025
注意
To export an ONNX model using TorchDynamo instead of TorchScript, please see Learn more about the TorchDynamo-based ONNX Exporter
Example: AlexNet from PyTorch to ONNX#
Here is a simple script which exports a pretrained AlexNet to an ONNX file named alexnet.onnx
. The call to torch.onnx.export
runs the model once to trace its execution and then exports the traced model to the specified file
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
The resulting alexnet.onnx
file contains a binary protocol buffer which contains both the network structure and parameters of the model you exported (in this case, AlexNet). The argument verbose=True
causes the exporter to print out a human-readable representation of the model
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
You can also verify the output using the ONNX library, which you can install using pip
pip install onnx
Then, you can run
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the model is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
You can also run the exported model with one of the many runtimes that support ONNX. For example after installing ONNX Runtime, you can load and run the model
import onnxruntime as ort
import numpy as np
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
Here is a more involved tutorial on exporting a model and running it with ONNX Runtime.
Tracing vs Scripting#
Internally, torch.onnx.export()
requires a torch.jit.ScriptModule
rather than a torch.nn.Module
. If the passed-in model is not already a ScriptModule
, export()
will use tracing to convert it to one
Tracing: If
torch.onnx.export()
is called with a Module that is not already aScriptModule
, it first does the equivalent oftorch.jit.trace()
, which executes the model once with the givenargs
and records all operations that happen during that execution. This means that if your model is dynamic, e.g., changes behavior depending on input data, the exported model will not capture this dynamic behavior. We recommend examining the exported model and making sure the operators look reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly the same as the traced run. If you want to export your model with dynamic control flow, you will need to use scripting.Scripting: Compiling a model via scripting preserves dynamic control flow and is valid for inputs of different sizes. To use scripting
Use
torch.jit.script()
to produce aScriptModule
.Call
torch.onnx.export()
with theScriptModule
as the model. Theargs
are still required, but they will be used internally only to produce example outputs, so that the types and shapes of the outputs can be captured. No tracing will be performed.
See Introduction to TorchScript and TorchScript for more details, including how to compose tracing and scripting to suit the particular requirements of different models.
Avoiding Pitfalls#
Avoid NumPy and built-in Python types#
PyTorch models can be written using NumPy or Python types and functions, but during tracing, any variables of NumPy or Python types (rather than torch.Tensor) are converted to constants, which will produce the wrong result if those values should change depending on the inputs.
For example, rather than using numpy functions on numpy.ndarrays
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
Use torch operators on torch.Tensors
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
And rather than use torch.Tensor.item()
(which converts a Tensor to a Python built-in number)
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
Use torch’s support for implicit casting of single-element tensors
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
Avoid Tensor.data#
Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph. Use torch.Tensor.detach()
instead. (Work is ongoing to remove Tensor.data entirely).
Avoid in-place operations when using tensor.shape in tracing mode#
In tracing mode, shapes obtained from tensor.shape
are traced as tensors, and share the same memory. This might cause a mismatch the final output values. As a workaround, avoid the use of inplace operations in these scenarios. For example, in the model
class Model(torch.nn.Module):
def forward(self, states):
batch_size, seq_length = states.shape[:2]
real_seq_length = seq_length
real_seq_length += 2
return real_seq_length + seq_length
real_seq_length
and seq_length
share the same memory in tracing mode. This could be avoided by rewriting the inplace operation
real_seq_length = real_seq_length + 2
Limitations#
Types#
Only
torch.Tensors
, numeric types that can be trivially converted to torch.Tensors (e.g. float, int), and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and outputs are accepted in tracing mode, butAny computation that depends on the value of a dict or a str input will be replaced with the constant value seen during the one traced execution.
任何输出为 dict 的内容都将被静默替换为其值的展平序列(键将被移除)。例如,
{"foo": 1, "bar": 2}
将会变成(1, 2)
。任何输出为 str 的内容都将被静默移除。
在脚本模式下,由于 ONNX 对嵌套序列的支持有限,因此不支持涉及元组和列表的某些操作。特别是,将元组追加到列表是不支持的。在跟踪模式下,嵌套序列将在跟踪期间自动展平。
运算符实现差异#
由于运算符实现上的差异,在不同运行时上运行导出的模型可能会产生彼此不同或与 PyTorch 不同的结果。通常这些差异在数值上很小,所以这只应在您的应用程序对这些微小差异敏感时才需要关注。
不支持的张量索引模式#
下面列出了无法导出的张量索引模式。如果您在导出模型时遇到问题,但该模型不包含以下任何不受支持的模式,请仔细检查您是否使用了最新的 opset_version
。
读取 / 获取#
在索引张量进行读取时,不支持以下模式
# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.
写入 / 设置#
在索引张量进行写入时,不支持以下模式
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
# or multiple consecutive tensor indices with rank == 1.
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.
# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.
# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
# data shape: [3, 4, 5]
# new_data shape: [5]
# expected new_data shape after broadcasting: [2, 2, 2, 5]
添加对运算符的支持#
导出包含不受支持的运算符的模型时,您会看到类似以下的错误消息
RuntimeError: ONNX export failed: Couldn't export operator foo
发生这种情况时,您可以采取一些措施:
更改模型以不使用该运算符。
创建一个符号函数来转换该运算符,并将其注册为自定义符号函数。
为 PyTorch 贡献代码,将相同的符号函数添加到
torch.onnx
本身。
如果您决定实现一个符号函数(我们希望您能将其贡献回 PyTorch!),以下是如何开始:
ONNX 导出器内部机制#
“符号函数”是一个将 PyTorch 运算符分解为一系列 ONNX 运算符的组合的函数。
导出期间,导出器会按照拓扑顺序访问 TorchScript 图中的每个节点(包含一个 PyTorch 运算符)。访问节点时,导出器会查找该运算符已注册的符号函数。符号函数是用 Python 实现的。名为 foo
的运算符的符号函数大致看起来像这样:
def foo(
g,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Adds the ONNX operations representing this PyTorch function by updating the
graph g with `g.op()` calls.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
None if it cannot be converted to ONNX.
"""
...
torch._C
类型是 ir.h 中定义的 C++ 类型之上的 Python 包装器。
添加符号函数的过程取决于运算符的类型。
ATen 运算符#
ATen 是 PyTorch 的内置张量库。如果运算符是 ATen 运算符(在 TorchScript 图中显示为带有 aten::
前缀),请确保它尚未受支持。
支持的运算符列表#
访问自动生成的 支持的 TorchScript 运算符列表,了解每个 opset_version
支持的运算符的详细信息。
为 ATen 或量化运算符添加支持#
如果运算符不在上面的列表中
在
torch/onnx/symbolic_opset<version>.py
中定义符号函数,例如 torch/onnx/symbolic_opset9.py。确保函数名与 ATen 函数名相同,ATen 函数名可能声明在torch/_C/_VariableFunctions.pyi
或torch/nn/functional.pyi
中(这些文件在构建时生成,因此在您签出代码时不会出现,直到您构建 PyTorch)。默认情况下,第一个参数是 ONNX 图。其他参数名必须与
.pyi
文件中的名称完全匹配,因为分派是通过关键字参数完成的。在符号函数中,如果运算符属于 ONNX 标准运算符集,我们只需要创建一个节点来表示图中的 ONNX 运算符。如果不是,我们可以组合几个具有与 ATen 运算符等效语义的标准运算符。
以下是处理 ELU
运算符缺失符号函数的示例。
如果我们运行以下代码:
print(
torch.jit.trace(
torch.nn.ELU(), # module
torch.ones(1) # example input
).graph
)
我们会看到类似这样的内容:
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
由于我们在图中看到了 aten::elu
,我们知道这是一个 ATen 运算符。
我们检查 ONNX 运算符列表,并确认 Elu
在 ONNX 中已标准化。
我们在 torch/nn/functional.pyi
中找到 elu
的签名:
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
我们将以下行添加到 symbolic_opset9.py
:
def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
return g.op("Elu", input, alpha_f=alpha)
现在 PyTorch 能够导出包含 aten::elu
运算符的模型了!
请参阅 torch/onnx/symbolic_opset*.py
文件以获取更多示例。
torch.autograd.Functions#
如果运算符是 torch.autograd.Function
的子类,有三种方法可以导出它。
静态符号方法#
您可以在函数类中添加一个名为 symbolic
的静态方法。它应该返回表示函数在 ONNX 中行为的 ONNX 运算符。例如:
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
内联 Autograd 函数#
在未为其后续的 torch.autograd.Function
提供静态符号方法,或者没有提供将 prim::PythonOp
注册为自定义符号函数的函数的情况下,torch.onnx.export()
会尝试内联与该 torch.autograd.Function
对应的图,以便该函数被分解为函数内使用的各个运算符。只要这些单个运算符受支持,导出就应该成功。例如:
class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()
此模型没有提供静态符号方法,但它被导出如下:
graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)
如果您需要避免内联 torch.autograd.Function
,您应该设置 operator_export_type
为 ONNX_FALLTHROUGH
或 ONNX_ATEN_FALLBACK
来导出模型。
自定义运算符#
您可以导出包含多个标准 ONNX 运算符组合的模型,或者由自定义 C++ 后端驱动的模型。
ONNX-script 函数#
如果某个运算符不是标准的 ONNX 运算符,但可以由多个现有 ONNX 运算符组成,您可以使用 ONNX-script 来创建一个外部 ONNX 函数来支持该运算符。您可以按照此示例导出它:
import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15
x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
@onnxscript.script(custom_opset)
def Selu(X):
alpha = 1.67326 # auto wrapped as Constants
gamma = 1.0507
alphaX = op.CastLike(alpha, X)
gammaX = op.CastLike(gamma, X)
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
pos = gammaX * X
zero = op.CastLike(0, X)
return op.Where(X <= zero, neg, pos)
# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
return g.onnxscript_op(Selu, X).setType(X.type())
# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
symbolic_name="aten::selu",
symbolic_fn=custom_selu,
opset_version=opset_version,
)
# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
model,
x,
"model.onnx",
opset_version=opset_version,
# only needed if you want to specify an opset version > 1.
custom_opsets={"onnx-script": 2}
)
上面的示例将其导出一个名为“onnx-script”的 opset 中的自定义运算符。导出自定义运算符时,您可以通过导出的 custom_opsets
字典指定自定义 opset 版本。如果未指定,自定义 opset 版本默认为 1。
注意:请务必匹配上述示例中提到的 opset 版本,并确保它们在导出器步骤中被使用。如何编写 onnx-script 函数的示例是其活跃开发中的一个 beta 版本。请遵循最新的 ONNX-script。
C++ 运算符#
如果模型使用 C++ 实现的自定义运算符,如 使用自定义 C++ 运算符扩展 TorchScript 中所述,您可以按照此示例导出它:
from torch.onnx import symbolic_helper
# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super().__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input1),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2}
)
上面的示例将其导出一个名为“custom_domain”的 opset 中的自定义运算符。导出自定义运算符时,您可以通过导出的 custom_opsets
字典指定自定义域版本。如果未指定,自定义 opset 版本默认为 1。
使用该模型的运行时需要支持自定义 op。请参阅 Caffe2 自定义 ops、ONNX Runtime 自定义 ops 或您选择的运行时的文档。
一次性发现所有不可转换的 ATen 运算符#
当由于不可转换的 ATen 运算符导致导出失败时,可能不止一个这样的运算符,但错误消息只提到第一个。要一次性发现所有不可转换的运算符,您可以:
# prepare model, args, opset_version
...
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
model, args, opset_version=opset_version
)
print(set(unconvertible_ops))
该集合是近似的,因为一些运算符在转换过程中可能会被移除,无需转换。其他一些运算符可能支持部分功能,在特定输入下可能导致转换失败,但这应该能让您对不支持的运算符有一个大致的了解。请随时在 GitHub 上提出 op 支持请求。
常见问题#
问:我导出了我的 LSTM 模型,但它的输入大小似乎是固定的?
跟踪器会记录示例输入的形状。如果模型应该接受动态形状的输入,请在调用
torch.onnx.export()
时设置dynamic_axes
。
问:如何导出包含循环的模型?
请参阅 跟踪 vs. 脚本。
问:如何导出带有原始类型输入(例如 int、float)的模型?
PyTorch 1.9 中添加了对原始数值类型输入的支持。但是,导出器不支持带有 str 输入的模型。
问:ONNX 是否支持隐式标量数据类型转换?
ONNX 标准不支持,但导出器会尝试处理这部分。标量被导出为常量张量。导出器会找出标量的正确数据类型。在极少数情况下,当它无法做到时,您需要手动指定数据类型,例如使用 dtype=torch.float32。如果您看到任何错误,请[创建一个 GitHub issue](pytorch/pytorch#issues)。
问:张量列表是否可以导出到 ONNX?
是的,对于
opset_version
>= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。
Python API#
函数#
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[source]
将模型导出为 ONNX 格式。
设置
dynamo=True
可启用新的 ONNX 导出逻辑,该逻辑基于torch.export.ExportedProgram
和一套更现代的翻译逻辑。这是导出模型到 ONNX 的推荐方式。当
dynamo=True
时:导出器尝试使用以下策略获取用于转换为 ONNX 的 ExportedProgram:
如果模型已经是 ExportedProgram,则按原样使用。
使用
torch.export.export()
并设置strict=False
。使用
torch.export.export()
并设置strict=True
。使用
draft_export
,它会移除数据相关操作中的一些健全性保证,以允许导出继续进行。如果导出器遇到任何不健全的数据相关操作,您将收到警告。使用
torch.jit.trace()
来跟踪模型,然后转换为 ExportedProgram。这是最不健全的策略,但对于将 TorchScript 模型转换为 ONNX 可能有用。
- 参数
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 要导出的模型。
args (tuple[Any, ...]) – 示例位置输入。任何非张量参数都将被硬编码到导出的模型中;任何张量参数都将成为导出模型的输入,顺序与元组中出现的顺序一致。
f (str | os.PathLike | None) – 输出 ONNX 模型文件的路径。例如 “model.onnx”。
export_params (bool) – 如果为 false,则不导出参数(权重)。
verbose (bool | None) – 是否启用详细日志记录。
input_names (Sequence[str] | None) – 分配给图输入节点的名称,按顺序。
output_names (Sequence[str] | None) – 分配给图输出节点的名称,按顺序。
opset_version (int | None) – 要针对的 默认(ai.onnx)opset 的版本。必须 >= 7。
dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
默认情况下,导出的模型的所有输入和输出张量的形状都将精确匹配
args
中给定的形状。要将张量的轴指定为动态(即仅在运行时知道),请将dynamic_axes
设置为具有以下架构的字典:- KEY (str): 输入或输出名称。每个名称也必须在
input_names
或 output_names
.
- KEY (str): 输入或输出名称。每个名称也必须在
- VALUE (dict 或 list): 如果是 dict,键是轴索引,值是轴名称。如果是
list,每个元素是轴索引。
例如
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], )
产生
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
同时
torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], }, )
产生
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
keep_initializers_as_inputs (bool) –
如果为 True,则导出图中的所有初始化器(通常对应于模型权重)也将作为输入添加到图中。如果为 False,则初始化器不作为输入添加到图中,只添加用户输入作为输入。
如果您打算在运行时提供模型权重,请将其设置为 True。如果您打算将权重设置为静态以允许后端/运行时进行更好的优化(例如常量折叠),请将其设置为 False。
dynamo (bool) – 是否使用
torch.export
ExportedProgram 而不是 TorchScript 来导出模型。external_data (bool) – 是否将模型权重保存为外部数据文件。这对于权重过大(超过 2GB)的模型是必需的。当为 False 时,权重与模型架构一起保存在 ONNX 文件中。
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型输入的动态形状的字典或元组。有关更多详细信息,请参阅
torch.export.export()
。仅当 dynamo 为 True 时使用(并且是首选)。请注意,dynamic_shapes 设计用于 dynamo=True 导出模型时使用,而 dynamic_axes 用于 dynamo=False。custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – 用于模型中运算符的自定义分解的字典。字典的键应该是 fx Node 中的可调用目标(例如
torch.ops.aten.stft.default
),值应该是使用 ONNX Script 构建该图的函数。此选项仅在 dynamo 为 True 时有效。report (bool) – 是否为导出过程生成 markdown 报告。此选项仅在 dynamo 为 True 时有效。
optimize (bool) – 是否优化导出的模型。此选项仅在 dynamo 为 True 时有效。默认为 True。
verify (bool) – 是否使用 ONNX Runtime 验证导出的模型。此选项仅在 dynamo 为 True 时有效。
profile (bool) – 是否对导出过程进行分析。此选项仅在 dynamo 为 True 时有效。
dump_exported_program (bool) – 是否将
torch.export.ExportedProgram
转储到文件。这对于调试导出器很有用。此选项仅在 dynamo 为 True 时有效。artifacts_dir (str | os.PathLike) – 用于保存调试工件(如报告和序列化的导出程序)的目录。此选项仅在 dynamo 为 True 时有效。
fallback (bool) – 如果 dynamo 导出器失败,是否回退到 TorchScript 导出器。此选项仅在 dynamo 为 True 时有效。启用回退时,建议即使提供了 dynamic_shapes 也设置 dynamic_axes。
training (_C_onnx.TrainingMode) – 已弃用的选项。而是应在导出模型之前设置模型的训练模式。
operator_export_type (_C_onnx.OperatorExportTypes) – 已弃用的选项。仅支持 ONNX。
do_constant_folding (bool) – 已弃用的选项。
custom_opsets (Mapping[str, int] | None) –
已弃用。一个字典:
KEY (str): opset 域名称。
VALUE (int): opset 版本。
如果
model
引用了自定义 opset 但未在此字典中提及,则 opset 版本将设置为 1。只能通过此参数指示自定义 opset 域名称和版本。export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) –
已弃用的选项。
启用将所有
nn.Module
前向调用导出为 ONNX 中的本地函数的标志。或者一个集合,用于指定要导出为 ONNX 中的本地函数的特定类型模块。此功能需要opset_version
>= 15,否则导出将失败。这是因为opset_version
< 15 意味着 IR 版本 < 8,这意味着不支持本地函数。模块变量将导出为函数属性。函数属性有两种类别:1. 带有注解的属性:通过 PEP 526 风格带有类型注解的类变量将导出为属性。带有注解的属性不在 ONNX 本地函数的子图中被使用,因为它们不是由 PyTorch JIT 跟踪创建的,但消费者可以使用它们来确定是否用特定的融合内核替换该函数。
2. 推断的属性:在模块的子图中使用到的变量。属性名称将带有前缀“inferred::”。这是为了与从 Python 模块注解中检索到的预定义属性区分开来。推断的属性在 ONNX 本地函数的子图中使用。
False
(默认): 将nn.Module
前向调用导出为细粒度节点。True
: 将所有nn.Module
前向调用导出为本地函数节点。- Set of type of nn.Module:导出
nn.Module
的前向调用作为局部函数节点, 仅当
nn.Module
的类型在集合中找到时。
- Set of type of nn.Module:导出
autograd_inlining (bool) – 已弃用。用于控制是否内联 autograd 函数的标志。有关更多详细信息,请参阅 pytorch/pytorch#74765。
- 返回
torch.onnx.ONNXProgram
,如果 dynamo 为 True,否则为 None。- 返回类型
ONNXProgram | None
版本 2.6 已更改:training 现已弃用。请在导出前设置模型的训练模式。operator_export_type 现已弃用。仅支持 ONNX。do_constant_folding 现已弃用。它始终启用。export_modules_as_functions 现已弃用。autograd_inlining 现已弃用。
版本 2.7 已更改:optimize 现在默认为 True。
- torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]#
为自定义运算符注册符号函数。
当用户为自定义/contrib 运算符注册符号时,强烈建议通过 setType API 为该运算符添加形状推断,否则在某些极端情况下导出的图可能具有错误的形状推断。setType 的一个示例是 test_operators.py 中的 test_aten_embedding_2。
有关示例用法,请参阅模块文档中的“自定义运算符”。
- torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source]#
注销
symbolic_name
。有关示例用法,请参阅模块文档中的“自定义运算符”。