Note
Go to the end to download the full example code.
Introduction to ONNX || Exporting a PyTorch Model to ONNX || Extending the ONNX exporter operator support || Export a model with control flow to ONNX
Export a model with control flow to ONNX#
Author: Xavier Dupré
Overview#
This tutorial demonstrates how to handle control flow logic when exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to circumvent them.
Conditional logic cannot be exported into ONNX unless it is refactored to use torch.cond(). Let's start with a simple model implementing a test.
What you will learn
How to refactor the model to use torch.cond() for export.
How to export a model with control flow logic to ONNX.
How to optimize the exported model using the ONNX optimizer.
Prerequisites#
torch >= 2.6
import torch
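
Since everything below relies on torch.cond() export support, it can help to verify the prerequisite programmatically. A small check added for this write-up (not part of the original tutorial):

print(torch.__version__)
# Parse "major.minor" out of a version string such as "2.8.0+cu128".
major, minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (2, 6), "this tutorial requires torch >= 2.6"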
Define the Models#
Two models are defined:
ForwardWithControlFlowTest: A model with a forward method containing an if-else conditional.
ModelWithControlFlowTest: A model that incorporates ForwardWithControlFlowTest as part of a simple MLP. The models are tested with a random input tensor to confirm they execute as expected.
class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x
class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out
model = ModelWithControlFlowTest()
Exporting the Model: First Attempt#
Exporting this model using torch.export.export fails because the control flow logic in the forward pass creates a graph break that the exporter cannot handle. This behavior is expected, as conditional logic not written using torch.cond() is unsupported.
A try-except block is used to capture the expected failure during the export process. If the export unexpectedly succeeds, an AssertionError is raised.
x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
    # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

    # File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
Caused by: (_export/non_strict_utils.py:1051 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
  File "/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
    if x.sum():
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
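
As the error message suggests, a best-effort draft export can surface every data-dependent issue in one pass. A minimal sketch, assuming torch.export.draft_export is available in your PyTorch build:

# draft_export traces through data-dependent conditions, records each issue
# in a report, and specializes the graph instead of failing outright.
ep = torch.export.draft_export(model, (x,))
print(ep)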
Using torch.onnx.export() with JIT Tracing#
When exporting the model using torch.onnx.export() with the dynamo=True parameter, the exporter defaults to using JIT tracing. This fallback allows the model to export, but the resulting ONNX graph may not faithfully represent the original model logic due to the limitations of tracing.
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
    # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

    # File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`...
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_

        # File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
        l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_);  l_x_ = None
        l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0);  l__self___mlp_0 = None

        # File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[][]cpu" = l__self___mlp_1.sum();  l__self___mlp_1 = sum_1 = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`...
[torch.onnx] Draft Export report:
###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Data dependent error.
    When exporting, we were unable to evaluate the value of `Eq(u0, 1)`.
    This was encountered 1 times.
    This occurred at the following user stacktrace:
        File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py, lineno 1773, in _wrapped_call_impl
        File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py, lineno 1784, in _call_impl
            if x.sum():

        Locals:
            x: ['Tensor(shape: torch.Size([1]), stride: (1,), storage_offset: 0)']

    And the following framework stacktrace:
        File /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py, lineno 1360, in __torch_function__
        File /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py, lineno 1407, in __torch_function__
            return func(*args, **kwargs)

    As a result, it was specialized to a constant (e.g. `1` in the 1st occurrence), and asserts were inserted into the graph.

    Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
    Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.8.0+cu128',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
        %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814,  0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
        %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
        %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default')}
    ),
) {
    0 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
    1 |  # node_linear
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
    2 |  # node_MatMul_3
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
    3 |  # node_linear_1
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
    4 |  # node_mul
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default"{2.0})
    return %"mul"<FLOAT,[1]>
}
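
Notice that the exported graph ends in an unconditional Mul node: the condition was evaluated once at trace time and only the x * 2 branch survived. A quick structural check makes this visible (a sketch, assuming the returned ONNXProgram exposes the underlying proto as model_proto):

# Collect the operator types in the main graph; a faithful export of the
# conditional would contain an `If` node.
node_types = {node.op_type for node in onnx_program.model_proto.graph.node}
print("If" in node_types)  # False: the branch was baked in by tracing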
Suggested Patch: Refactoring with torch.cond()#
To make the control flow exportable, the tutorial demonstrates replacing the forward method in ForwardWithControlFlowTest with a refactored version that uses torch.cond().

Details of the Refactoring:
Two auxiliary functions (identity2 and neg) represent the branches of the conditional logic.
torch.cond() is used to specify the condition, the two branches, and the input arguments.
The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.
def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
<class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
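
Before exporting, we can confirm in eager mode that torch.cond() dispatches to the correct branch (a small check added here, not part of the original tutorial):

# A positive sum selects identity2, a non-positive sum selects neg.
print(new_forward(torch.tensor([1.0])))   # tensor([2.]) from the x * 2 branch
print(new_forward(torch.tensor([-1.0])))  # tensor([1.]) from the -x branch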
Let's see what the FX graph looks like.
print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
            # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias);  x = p_mlp_0_weight = p_mlp_0_bias = None
            linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias);  linear = p_mlp_1_weight = p_mlp_1_bias = None

            # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:244 in forward, code: input = module(input)
            sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

            # File: <eval_with_key>.25:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (linear_1,));  gt = true_graph_0 = false_graph_0 = linear_1 = None
            getitem: "f32[1]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                # File: <eval_with_key>.22:6 in forward, code: mul = l_args_3_0__1.mul(2);  l_args_3_0__1 = None
                mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2);  linear_1 = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                # File: <eval_with_key>.23:6 in forward, code: neg = l_args_3_0__1.neg();  l_args_3_0__1 = None
                neg: "f32[1]" = torch.ops.aten.neg.default(linear_1);  linear_1 = None
                return (neg,)

Graph signature:
    # inputs
    p_mlp_0_weight: PARAMETER target='mlp.0.weight'
    p_mlp_0_bias: PARAMETER target='mlp.0.bias'
    p_mlp_1_weight: PARAMETER target='mlp.1.weight'
    p_mlp_1_bias: PARAMETER target='mlp.1.bias'
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}
Let's export it again.
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.8.0+cu128',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
        %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814,  0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
        %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
        %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
        %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
    ),
) {
    0 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
    1 |  # node_linear
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
    2 |  # node_MatMul_3
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
    3 |  # node_linear_1
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
    4 |  # node_sum_1
         %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
    5 |  # node_gt
         %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
    6 |  # node_cond__0
         %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
            graph(
                name=true_graph_0,
                inputs=(
                ),
                outputs=(
                    %"mul_true_graph_0"<FLOAT,[1]>
                ),
            ) {
                0 |  # node_mul
                     %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                return %"mul_true_graph_0"<FLOAT,[1]>
            }, else_branch=
            graph(
                name=false_graph_0,
                inputs=(
                ),
                outputs=(
                    %"neg_false_graph_0"<FLOAT,[1]>
                ),
            ) {
                0 |  # node_neg
                     %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                return %"neg_false_graph_0"<FLOAT,[1]>
            }}
    return %"getitem"<FLOAT,[1]>
}
We can optimize the model and get rid of the model local functions created to capture the control flow branches.
onnx_program.optimize()
print(onnx_program.model)
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.8.0+cu128',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
        %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814,  0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
        %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
        %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
        %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
    ),
) {
    0 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
    1 |  # node_linear
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
    2 |  # node_MatMul_3
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
    3 |  # node_linear_1
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
    4 |  # node_sum_1
         %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
    5 |  # node_gt
         %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
    6 |  # node_cond__0
         %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
            graph(
                name=true_graph_0,
                inputs=(
                ),
                outputs=(
                    %"mul_true_graph_0"<FLOAT,[1]>
                ),
            ) {
                0 |  # node_mul
                     %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                return %"mul_true_graph_0"<FLOAT,[1]>
            }, else_branch=
            graph(
                name=false_graph_0,
                inputs=(
                ),
                outputs=(
                    %"neg_false_graph_0"<FLOAT,[1]>
                ),
            ) {
                0 |  # node_neg
                     %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                return %"neg_false_graph_0"<FLOAT,[1]>
            }}
    return %"getitem"<FLOAT,[1]>
}
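
Finally, it is worth checking that the optimized program matches the eager model at runtime. A minimal sketch, assuming onnxruntime is installed (the file name is arbitrary):

import onnxruntime as ort

# Serialize the optimized program and run it with ONNX Runtime.
onnx_program.save("model_with_cond.onnx")
sess = ort.InferenceSession("model_with_cond.onnx", providers=["CPUExecutionProvider"])
(onnx_out,) = sess.run(None, {"x": x.numpy()})

# The ONNX result should match the eager PyTorch result for the same input.
torch.testing.assert_close(model(x), torch.from_numpy(onnx_out))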
Conclusion#
This tutorial demonstrates the challenges of exporting models with conditional logic to ONNX and presents a practical solution using torch.cond(). While the default exporters may fail or produce imperfect graphs, refactoring the model's logic ensures compatibility and generates a faithful ONNX representation.
By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
Further reading#
The list below refers to tutorials ranging from basic examples to advanced scenarios, not necessarily in the order they are listed. Feel free to jump directly to specific topics of your interest, or sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
Total running time of the script: (0 minutes 2.895 seconds)