评价此页

草稿导出#

创建于:2025 年 6 月 13 日 | 最后更新于:2025 年 7 月 16 日

警告

此功能不适用于生产环境,旨在作为调试 `torch.export` 跟踪错误的工具。

草稿导出是导出功能的一个新版本,旨在稳定地生成一个图,即使存在潜在的正确性问题,并生成一个报告,列出导出在跟踪过程中遇到的所有问题,并提供额外的调试信息。对于没有假内核的自定义运算符,它还将生成一个配置文件,您可以注册该文件以自动生成假内核。

您是否曾尝试使用 torch.export.export() 导出模型,却遇到了数据依赖问题?您修复了它,但又遇到了缺少假内核的问题。解决此问题后,您又遇到了另一个数据依赖问题。您不禁想,要是有一种方法可以让我只获取一个图来调试,并能在同一处查看所有问题,这样我就可以以后再修复它们……

现在有 draft_export 来帮忙!

draft_export 是导出功能的一个版本,它将始终成功导出图,即使存在潜在的正确性问题。这些问题随后将被编译成报告,以便更清晰地可视化,以后可以进行修复。

它能捕获哪些类型的错误?#

草稿导出有助于捕获和调试以下错误:

  • 数据依赖错误上的守卫

  • 约束违反错误

  • 缺少假内核

  • 错误的假内核实现

它是如何工作的?#

在正常的导出中,我们会将示例输入转换为 `FakeTensors`,并使用它们来记录操作并将程序跟踪为图。形状可能改变的输入张量(通过 `dynamic_shapes` 标记)或张量中的值(通常来自 `.item()` 调用)将被表示为符号形状(`SymInt`),而不是具体的整数。然而,在跟踪过程中可能会发生一些问题——我们可能会遇到无法评估的守卫,例如,如果我们想检查张量中的某个项是否大于 0(`u0 >= 0`)。由于跟踪器对 `u0` 的值一无所知,它将抛出数据依赖错误。如果模型使用了自定义运算符但尚未为其定义假内核,那么我们将因 `fake_tensor.UnsupportedOperatorException` 而出错,因为导出不知道如何在 `FakeTensors` 上应用它。如果自定义运算符的假内核实现不正确,导出将默默地生成一个不匹配即时行为的不正确图。

为了修复上述错误,草稿导出使用 *真实张量跟踪* 来指导我们在跟踪时如何进行。当我们使用假张量跟踪模型时,对于在假张量上发生的每个操作,草稿导出还会对存储的真实张量(来自传递给导出的示例输入)运行该运算符。这使我们能够解决上述错误:当我们遇到无法评估的守卫(如 `u0 >= 0`)时,我们将使用存储的真实张量值来评估该守卫。运行时断言将被添加到图中,以确保图断言与我们在跟踪时所假设的守卫相同。如果我们遇到没有假内核的自定义运算符,我们将使用存储的真实张量运行该运算符的正常内核,并返回一个具有相同秩但形状未绑定的假张量。由于我们有每个操作的真实张量输出,我们将把这个与假内核的假张量输出进行比较。如果假内核实现不正确,我们将捕获此行为并生成更正确的假内核。

如何使用草稿导出?#

假设您正在尝试导出以下代码片段

class M(torch.nn.Module):
    def forward(self, x, y, z):
        res = torch.ops.mylib.foo2(x, y)

        a = res.item()
        a = -a
        a = a // 3
        a = a + 5

        z = torch.cat([z, z])

        torch._check_is_size(a)
        torch._check(a < z.shape[0])

        return z[:a]

inp = (torch.tensor(3), torch.tensor(4), torch.ones(3, 3))

ep = torch.export.export(M(), inp)

这会因为 `mylib.foo2` 的“缺少假内核”错误以及由于 `z` 使用 `a`(一个未绑定的 `symint`)进行切片而导致的 `GuardOnDataDependentExpression` 错误。

要调用 `draft-export`,我们可以将 `torch.export` 行替换为以下内容:

ep = torch.export.draft_export(M(), inp)

ep 是一个有效的 `ExportedProgram`,现在可以传递给进一步的环境!

使用草稿导出进行调试#

在草稿导出的终端输出中,您应该会看到以下消息:

#########################################################################################
WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph.
To view the report of failures in an html page, please run the command:
    `tlparse /tmp/export_angelayi/dedicated_log_torch_trace_axpofwe2.log --export`
Or, you can view the errors in python by inspecting `print(ep._report)`.
########################################################################################

草稿导出会自动转储 `tlparse` 的日志。您可以使用 `print(ep._report)` 查看跟踪错误,或者将日志传递给 `tlparse` 来生成 HTML 报告。

在终端中运行 `tlparse` 命令将生成一个 tlparse HTML 报告。这是一个 `tlparse` 报告的示例:

../_images/draft_export_report.png

点击“数据依赖错误”,我们将看到以下页面,其中包含帮助调试此错误的信息。具体来说,它包含:

  • 发生此错误的堆栈跟踪

  • 局部变量及其形状列表

  • 有关此守卫如何创建的信息

../_images/draft_export_report_dde.png

返回的 `ExportedProgram`#

由于草稿导出根据示例输入对代码路径进行专门化,因此从草稿导出返回的 `ExportedProgram` **至少**保证对于给定的示例输入是可运行的并返回正确的结果。其他输入也可以工作,只要它们匹配草稿导出时所采取的守卫。

例如,如果我们有一个基于值是否大于 5 的图分支,如果在草稿导出中我们的示例输入大于 5,那么返回的 `ExportedProgram` 将专门化该分支,并断言该值大于 5。这意味着如果您传入另一个大于 5 的值,程序将成功,但如果您传入一个小于 5 的值,程序将失败。这比 `torch.jit.trace` 更安全,后者会默默地专门化分支。`torch.export` 支持两个分支的正确方法是使用 `torch.cond` 重写代码,它将捕获两个分支。

由于图中的运行时断言,返回的 `exported-program` 也可以使用 `torch.export` 或 `torch.compile` 进行重新跟踪,并且在自定义运算符缺少假内核的情况下需要进行少量额外操作。

生成假内核#

如果自定义运算符不包含假实现,目前草稿导出将使用真实张量传播来获取运算符的输出并继续跟踪。然而,如果我们使用假张量运行导出的程序或重新跟踪导出的模型,我们仍然会失败,因为仍然没有假内核实现。

为了解决这个问题,在草稿导出之后,我们将为遇到的每个自定义运算符调用生成一个运算符配置文件,并将其存储在附加到导出的程序的报告中:`ep._report.op_profiles`。然后,用户可以使用上下文管理器 `torch._library.fake_profile.unsafe_generate_fake_kernels` 基于这些运算符配置文件生成和注册一个假实现。这样,未来的假张量重新跟踪就能正常工作。

工作流程可能如下所示:

class M(torch.nn.Module):
    def forward(self, a, b):
        res = torch.ops.mylib.foo(a, b)  # no fake impl
        return res

ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)))

with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
    decomp = ep.run_decompositions()

new_inp = (
    torch.ones(2, 3, 4),
    torch.ones(2, 3, 4),
)

# Save the profile to a yaml and check it into a codebase
save_op_profiles(ep._report.op_profiles, "op_profile.yaml")
# Load the yaml
loaded_op_profile = load_op_profiles("op_profile.yaml")

运算符配置文件是一个字典,将运算符名称映射到一组配置文件,这些配置文件描述了运算符的输入和输出,并且可以手动编写、保存到 yaml 文件并提交到代码库。下面是一个 `mylib.foo.default` 的配置文件的示例:

"mylib.foo.default": {
    OpProfile(
        args_profile=(
            TensorMetadata(
                rank=2,
                dtype=torch.float32,
                device=torch.device("cpu"),
                layout=torch.strided,
            ),
            TensorMetadata(
                rank=2,
                dtype=torch.float32,
                device=torch.device("cpu"),
                layout=torch.strided,
            ),
        ),
        out_profile=TensorMetadata(
            rank=2,
            dtype=torch.float32,
            device=torch.device("cpu"),
            layout=torch.strided,
        ),
    )
}

`mylib.foo.default` 的配置文件只包含一个配置文件,它表示对于 2 个秩为 2、`dtype` 为 `torch.float32`、设备为 `cpu` 的输入张量,我们将返回一个秩为 2、`dtype` 为 `torch.float32`、设备为 `cpu` 的输出张量。使用上下文管理器,将生成一个假内核,当给定 2 个秩为 2 的输入张量(以及其他张量元数据)时,它将输出一个秩为 2 的张量(以及其他张量元数据)。

如果该运算符还支持其他秩,那么我们可以将该配置文件添加到此配置文件的列表中,方法是手动将其添加到现有配置文件中,或使用新输入重新运行草稿导出以获取新配置文件,这样生成的假内核将支持更多输入类型。否则,它将出错。

下一步去哪里?#

既然我们已经成功使用草稿导出创建了一个 `ExportedProgram`,我们可以使用像 `AOTInductor` 这样的进一步编译器来优化其性能并生成可运行的制品。这个优化版本可以用于部署。同时,我们可以利用草稿导出生成的报告来识别和修复遇到的 `torch.export` 错误,以便原始模型可以直接使用 `torch.export` 进行跟踪。