torch.onnx.verification#
创建于: 2025年3月18日 | 最后更新于: 2025年6月10日
ONNX 验证模块提供了一系列用于验证 ONNX 模型正确性的工具。
- torch.onnx.verification.verify_onnx_program(onnx_program, args=None, kwargs=None, compare_intermediates=False)[source]#
通过比较 ExportedProgram 的预期值来验证 ONNX 模型。
- class torch.onnx.verification.VerificationInfo(name, max_abs_diff, max_rel_diff, abs_diff_hist, rel_diff_hist, expected_dtype, actual_dtype)#
ONNX 程序中值的验证信息。
此类包含预期值和实际值之间的最大绝对差、最大相对差以及绝对差和相对差的直方图。它还包括预期的和实际的数据类型。
直方图表示为张量元组,其中第一个张量是直方图计数,第二个张量是分箱边缘。
- 变量
name (str) – 值的名称(输出或中间值)。
max_abs_diff (float) – 预期值和实际值之间的最大绝对差。
max_rel_diff (float) – 预期值和实际值之间的最大相对差。
abs_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 一个表示绝对差直方图的张量元组。第一个张量是直方图计数,第二个张量是分箱边缘。
rel_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 一个表示相对差直方图的张量元组。第一个张量是直方图计数,第二个张量是分箱边缘。
expected_dtype (torch.dtype) – 预期值的数据类型。
actual_dtype (torch.dtype) – 实际值的实际数据类型。
- torch.onnx.verification.verify(model, input_args, input_kwargs=None, do_constant_folding=True, dynamic_axes=None, input_names=None, output_names=None, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, fixed_batch_size=False, use_external_data=False, additional_test_inputs=None, options=None)[source]#
验证模型导出到 ONNX 与原始 PyTorch 模型的一致性。
已弃用,版本 2.7 起:请考虑使用
torch.onnx.export(..., dynamo=True)
并使用返回的ONNXProgram
来测试 ONNX 模型。- 参数
model (_ModelType) – 请参阅
torch.onnx.export()
。input_args (_InputArgsType) – 请参阅
torch.onnx.export()
。input_kwargs (_InputKwargsType | None) – 请参阅
torch.onnx.export()
。do_constant_folding (bool) – 请参阅
torch.onnx.export()
。dynamic_axes (Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] | None) – 请参阅
torch.onnx.export()
。input_names (Sequence[str] | None) – 请参阅
torch.onnx.export()
。output_names (Sequence[str] | None) – 请参阅
torch.onnx.export()
。training (_C_onnx.TrainingMode) – 请参阅
torch.onnx.export()
。opset_version (int | None) – 请参阅
torch.onnx.export()
。keep_initializers_as_inputs (bool) – 请参阅
torch.onnx.export()
。verbose (bool) – 请参阅
torch.onnx.export()
。fixed_batch_size (bool) – 旧参数,仅用于 RNN 测试用例。
use_external_data (bool) – 显式指定是否使用外部数据导出模型。
additional_test_inputs (Sequence[_InputArgsType] | None) – 元组列表。每个元组是一组用于测试的输入参数。目前仅支持
*args
。options (VerificationOptions | None) – 一个控制验证行为的 VerificationOptions 对象。
- 引发
AssertionError – 如果 ONNX 模型和 PyTorch 模型的输出在指定精度下不相等。
ValueError – 如果提供的参数无效。
已弃用#
以下类和函数已弃用。