注意
转到结尾处下载完整的示例代码。
ONNX 简介 || 将 PyTorch 模型导出到 ONNX || 扩展 ONNX 导出器算子支持 || 将带控制流的模型导出到 ONNX
扩展 ONNX 导出器算子支持#
创建于:2023年10月06日 | 最后更新:2025年3月05日 | 最后验证:2024年11月05日
作者: Ti-Tai Wang, Justin Chu
概述#
本教程介绍如何为不受支持的 PyTorch 算子创建 ONNX 实现,或用您自己的实现替换现有实现。
我们将涵盖需要扩展 ONNX 导出器算子支持的三种场景:
覆盖现有 PyTorch 算子的实现
使用自定义 ONNX 算子
支持自定义 PyTorch 算子
您将学到什么
如何在 ONNX 中覆盖或添加对 PyTorch 算子的支持。
如何为专门的运行时集成自定义 ONNX 算子。
如何实现自定义 PyTorch 算子并将其转换为 ONNX。
先决条件#
在开始本教程之前,请确保您已完成以下先决条件:
torch >= 2.6
目标 PyTorch 算子
在继续之前,完成 ONNX Script 教程
使用 ONNX Script 实现算子
覆盖现有 PyTorch 算子的实现#
尽管 ONNX 导出器团队尽最大努力支持所有 PyTorch 算子,但其中一些可能尚未得到支持。在本节中,我们将演示如何将不受支持的 PyTorch 算子添加到 ONNX 注册表中。
注意
实现不受支持的 PyTorch 算子的步骤与用自定义实现替换现有 PyTorch 算子的步骤相同。因为在本教程中我们实际上没有一个不受支持的 PyTorch 算子可用,所以我们将利用这一点,用自定义实现替换 torch.ops.aten.add.Tensor
的实现,就像该算子未被 ONNX 导出器实现一样。
当模型因算子不受支持而无法导出到 ONNX 时,ONNX 导出器将显示类似以下的错误消息:
No decompositions registered for [...]
错误消息指出,不受支持的 PyTorch 算子是 torch.ops.aten.add.Tensor
。该算子的类型为 <class 'torch._ops.OpOverload'>
,我们将使用这个算子作为目标来注册我们的自定义实现。
import torch
import onnxscript
# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op
# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
def forward(self, input_x, input_y):
return torch.ops.aten.add.Tensor(input_x, input_y)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
# To distinguish the custom implementation from the builtin one, we switch the order of the inputs
return op.Add(other, self)
x = torch.tensor([1.0])
y = torch.tensor([2.0])
# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
Model().eval(),
(x, y),
dynamo=True,
custom_translation_table={
torch.ops.aten.add.Tensor: custom_aten_add,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
现在让我们检查模型,并验证模型是否使用了自定义实现。
print(onnx_program.model)
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input_x"<FLOAT,[1]>,
%"input_y"<FLOAT,[1]>
),
outputs=(
%"add"<FLOAT,[1]>
),
) {
0 | # node_add
%"add"<FLOAT,[1]> ⬅️ ::Add(%"input_y", %"input_x")
return %"add"<FLOAT,[1]>
}
转换正在使用我们的自定义实现:在节点 node_Add_0
中,input_y
现在排在第一位,而 input_x
排在第二位。
我们可以使用 ONNX Runtime 运行模型,并通过直接在输入张量上调用 torch.onnx.ONNXProgram
来验证结果。
result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))
使用自定义 ONNX 算子#
在这种情况下,我们使用标准的 PyTorch 算子创建一个模型,但运行时(例如 Microsoft 的 ONNX Runtime)可以为该内核提供自定义实现,从而有效替换现有实现。
在以下示例中,我们使用 ONNX Runtime 提供的 com.microsoft.Gelu
算子,它与 ONNX 规范中的 Gelu
不同。
class GeluModel(torch.nn.Module):
def forward(self, input_x):
return torch.ops.aten.gelu(input_x)
# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT
@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
return microsoft_op.Gelu(self)
onnx_program = torch.onnx.export(
GeluModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.aten.gelu.default: custom_aten_gelu,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
让我们检查模型,并验证模型使用的 op_type 是来自命名空间 com.microsoft
的 Gelu
。
print(onnx_program.model)
<
ir_version=10,
opset_imports={'com.microsoft': 1, '': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input_x"<FLOAT,[1]>
),
outputs=(
%"gelu"<FLOAT,[1]>
),
) {
0 | # n0
%"gelu"<FLOAT,[1]> ⬅️ com.microsoft::Gelu(%"input_x")
return %"gelu"<FLOAT,[1]>
}
与前面的示例类似,我们可以使用 ONNX Runtime 运行模型并验证结果。
result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))
支持自定义 PyTorch 算子#
在这种情况下,该算子是用户实现并注册到 PyTorch 的算子。
在下面的示例中,我们想使用一个自定义算子,它接受一个张量输入,并返回一个输出。该算子将输入与其自身相加,并返回四舍五入后的结果。
首先,我们假设自定义算子是使用 torch.library.custom_op()
实现和注册的。您可以参考 在 Python 中创建新的自定义算子,获取有关如何创建自定义算子的详细指南。
# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
return torch.round(input + input)
@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
return torch.empty_like(tensor_x)
class AddAndRoundModel(torch.nn.Module):
def forward(self, input):
return add_and_round_op(input)
# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
return op.Round(op.Add(input, input))
onnx_program = torch.onnx.export(
AddAndRoundModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
model=
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input"<FLOAT,[1]>
),
outputs=(
%"add_and_round_op"<FLOAT,[1]>
),
) {
0 | # node_Add_0
%"val_0"<FLOAT,[1]> ⬅️ ::Add(%"input", %"input")
1 | # node_add_and_round_op
%"add_and_round_op"<FLOAT,[1]> ⬅️ ::Round(%"val_0")
return %"add_and_round_op"<FLOAT,[1]>
}
,
exported_program=
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, input: "f32[1]"):
input_1 = input
# File: /var/lib/workspace/beginner_source/onnx/onnx_registry_tutorial.py:215 in forward, code: return add_and_round_op(input)
add_and_round_op: "f32[1]" = torch.ops.mylibrary.add_and_round_op.default(input_1); input_1 = None
return (add_and_round_op,)
Graph signature:
# inputs
input: USER_INPUT
# outputs
add_and_round_op: USER_OUTPUT
Range constraints: {}
)
转换正在使用我们的自定义实现,将 torch.export.ExportedProgram`
中的 torch.ops.mylibrary.add_and_round_op.default
算子转换为 ONNX 算子 Add
和 Round
。
最后我们验证结果。
结论#
恭喜!在本教程中,我们探讨了 custom_translation_table
选项,并了解了如何使用 ONNX Script 为不受支持的或现有的 PyTorch 算子创建自定义实现。
最后,我们利用 ONNX Runtime 执行模型并与 PyTorch 的结果进行比较,这让我们对在 ONNX 生态系统中处理不受支持的算子有了全面的了解。
延伸阅读#
以下列表引用了从基本示例到高级场景的教程,不一定按所列顺序排列。您可以随时直接跳到您感兴趣的特定主题,或者坐下来愉快地浏览所有内容,以了解有关 ONNX 导出器的所有知识。
脚本总运行时间: (0 分 2.831 秒)