草稿导出#
创建于: 2025年6月13日 | 最后更新于: 2025年6月13日
警告
此功能不适用于生产环境,而是作为调试 torch.export 跟踪错误的工具。
草稿导出是导出功能的新版本,旨在始终生成图,即使可能存在健全性问题,并生成一个报告,列出导出在跟踪过程中遇到的所有问题并提供其他调试信息。对于没有假内核的自定义运算符,它还将生成一个配置文件,您可以注册该文件以自动生成假内核。
您是否曾经尝试使用 torch.export.export()
导出模型,结果却遇到与数据相关的错误?修复了该问题,但又遇到了缺少假内核的问题。解决该问题后,您又遇到了另一个与数据相关的错误。您可能会想,我真希望有一种方法可以让我得到一个图来玩,并且能够在一个地方看到所有问题,以便以后修复它们……
draft_export
来拯救你!
draft_export
是导出功能的一个版本,它将始终成功导出图,即使可能存在健全性问题。然后,这些问题将被编译成报告,以便更清晰地可视化,并可以稍后修复。
它是如何工作的?#
在正常导出中,我们会将示例输入转换为 FakeTensors,并使用它们来记录操作并将程序跟踪到图中。可变形状的输入张量(通过 dynamic_shapes
标记)或张量中的值(通常来自 .item() 调用)将表示为符号形状 (SymInt
),而不是具体整数。但是,在跟踪过程中可能会出现一些问题 - 我们可能会遇到无法评估的守卫,例如,如果我们想检查张量中的某个项是否大于 0 (u0 >= 0
)。由于跟踪器对 u0
的值一无所知,它将抛出与数据相关的错误。如果模型使用了自定义运算符但没有为它定义假内核,那么我们将因 fake_tensor.UnsupportedOperatorException
而出错,因为导出不知道如何将其应用于 FakeTensors
。如果自定义运算符的假内核实现不正确,导出将默默地生成一个与急切行为不匹配的错误图。
为了修复上述错误,草稿导出使用 *真实张量跟踪* 来指导我们如何继续跟踪。当我们使用假张量跟踪模型时,对于在假张量上发生的每个操作,草稿导出还将使用存储的真实张量来运行该运算符,这些真实张量来自传递给导出的示例输入。这使我们能够解决上述错误:当我们遇到无法评估的守卫时,例如 u0 >= 0
,我们将使用存储的真实张量值来评估该守卫。运行时断言将被添加到图中,以确保图断言与我们在跟踪时假设的相同守卫。如果我们遇到一个没有假内核的自定义运算符,我们将使用存储的真实张量来运行该运算符的正常内核,并返回一个具有相同等级但形状未绑定的假张量。由于我们拥有每个操作的真实张量输出,我们将此与假内核的假张量输出进行比较。如果假内核实现不正确,我们将捕获此行为并生成一个更正确的假内核。
如何使用草稿导出?#
假设您正在尝试导出以下代码:
class M(torch.nn.Module):
def forward(self, x, y, z):
res = torch.ops.mylib.foo2(x, y)
a = res.item()
a = -a
a = a // 3
a = a + 5
z = torch.cat([z, z])
torch._check_is_size(a)
torch._check(a < z.shape[0])
return z[:a]
inp = (torch.tensor(3), torch.tensor(4), torch.ones(3, 3))
ep = torch.export.export(M(), inp)
这会导致 mylib.foo2
的“缺少假内核”错误,然后是由于使用未绑定的符号整数 a
对 z
进行切片而导致的 GuardOnDataDependentExpression
。
要调用 draft-export
,我们可以将 torch.export
行替换为以下内容:
ep = torch.export.draft_export(M(), inp)
ep
是一个有效的 ExportedProgram,现在可以传递给其他环境!
使用草稿导出进行调试#
在草稿导出的终端输出中,您应该会看到以下消息:
#########################################################################################
WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph.
To view the report of failures in an html page, please run the command:
`tlparse /tmp/export_angelayi/dedicated_log_torch_trace_axpofwe2.log --export`
Or, you can view the errors in python by inspecting `print(ep._report)`.
########################################################################################
草稿导出会自动转储 tlparse
的日志。您可以通过 print(ep._report)
查看跟踪错误,或者将日志传递给 tlparse
以生成 HTML 报告。
在终端运行 tlparse
命令将生成一个 tlparse HTML 报告。下面是一个 tlparse
报告的示例。

点击“与数据相关的错误”,我们将看到以下页面,其中包含帮助调试此错误的信息。具体来说,它包含:
发生此错误的堆栈跟踪
本地变量及其形状列表
关于此守卫如何创建的信息

返回的 Exported Program#
由于草稿导出基于示例输入对代码路径进行特化,因此通过草稿导出生成的 Exported Program 保证至少对于给定的示例输入是可运行的并且能返回正确的结果。其他输入也可以工作,只要它们匹配我们草稿导出时所采取的相同守卫。
例如,如果我们有一个基于值大于 5 的图分支,如果在草稿导出中我们的示例输入大于 5,那么返回的 ExportedProgram
将特化到该分支,并断言该值大于 5。这意味着程序将在您传入另一个大于 5 的值时成功,但在您传入小于 5 的值时会失败。这比 torch.jit.trace
更可靠,后者会默默地特化到该分支。对于 torch.export
支持两个分支的正确方法是使用 torch.cond
重写代码,该代码将捕获两个分支。
由于图中存在运行时断言,因此返回的 exported-program 也可以使用 torch.export
或 torch.compile
进行重跟踪,在缺少假内核的自定义运算符的情况下需要进行少量添加。
生成假内核#
如果自定义运算符不包含假实现,目前草稿导出将使用真实张量传播来获取运算符的输出并继续跟踪。但是,如果我们使用假张量运行导出的程序或重跟踪导出的模型,我们仍然会失败,因为仍然没有假内核实现。
为了解决这个问题,在草稿导出之后,我们将为遇到的每个自定义运算符调用生成一个运算符配置文件,并将其存储在附加到导出的程序的报告中:ep._report.op_profiles
。然后,用户可以使用上下文管理器 torch._library.fake_profile.unsafe_generate_fake_kernels
根据这些运算符配置文件生成并注册一个假实现。这样,未来的假张量重跟踪将起作用。
工作流程如下:
class M(torch.nn.Module):
def forward(self, a, b):
res = torch.ops.mylib.foo(a, b) # no fake impl
return res
ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)))
with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
decomp = ep.run_decompositions()
new_inp = (
torch.ones(2, 3, 4),
torch.ones(2, 3, 4),
)
# Save the profile to a yaml and check it into a codebase
save_op_profiles(ep._report.op_profiles, "op_profile.yaml")
# Load the yaml
loaded_op_profile = load_op_profiles("op_profile.yaml")
运算符配置文件是一个将运算符名称映射到一组配置文件的字典,这些配置文件描述了运算符的输入和输出,并且可以手动编写、保存到 yaml 文件中,然后检入代码库。以下是 mylib.foo.default
的配置文件示例:
"mylib.foo.default": {
OpProfile(
args_profile=(
TensorMetadata(
rank=2,
dtype=torch.float32,
device=torch.device("cpu"),
layout=torch.strided,
),
TensorMetadata(
rank=2,
dtype=torch.float32,
device=torch.device("cpu"),
layout=torch.strided,
),
),
out_profile=TensorMetadata(
rank=2,
dtype=torch.float32,
device=torch.device("cpu"),
layout=torch.strided,
),
)
}
mylib.foo.default
的配置文件仅包含一个配置,该配置说明对于两个秩为 2、数据类型为 torch.float32
、设备为 cpu
的输入张量,我们将返回一个秩为 2、数据类型为 torch.float32
、设备为 cpu
的输出张量。使用上下文管理器,然后将生成一个伪造的内核,该内核接收两个秩为 2 的输入张量(以及其他张量元数据),并输出一个秩为 2 的张量(以及其他张量元数据)。
如果该算子还支持其他输入秩,我们可以将该配置添加到此配置列表中,方法是手动添加到现有配置中,或者使用新的输入重新运行 draft-export 以获取新配置,这样生成的伪造内核将支持更多输入类型。否则将出错。
接下来做什么?#
既然我们已经使用 draft-export 成功创建了 ExportedProgram
,我们可以使用 AOTInductor
等进一步的编译器来优化其性能并生成可运行的产物。然后,此优化版本可用于部署。同时,我们可以利用 draft-export 生成的报告来识别和修复遇到的 torch.export
错误,以便使用 torch.export
直接跟踪原始模型。