评价此页

PyTorch 2.0 疑难解答 (旧)#

创建于:2025年6月6日 | 最后更新于:2025年8月2日

作者: Michael Lazos

注意

本文档已过时,现在主要作为运行 torch.compile 最小化器(minifier)的参考资源。请参阅 更新的疑难解答文档。此外,还有一份更 全面的 torch.compile 手册 可供参考。

我们正在积极开发调试工具、性能剖析器,并改进我们的错误和警告信息。下表列出了可用的工具及其典型用法。如需进一步帮助,请参阅 诊断运行时错误

标题#

工具

目的

用法

信息日志

查看编译的摘要步骤

torch._logging.set_logs(dynamo = logging.INFO)TORCH_LOGS="dynamo"

调试日志

查看编译的详细步骤(打印跟踪的每个指令)

torch._logging.set_logs(dynamo = logging.DEBUG)torch._dynamo.config.verbose = True,或 TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1

适用于任何后端的最小化器

查找可重现任何后端错误的最短子图

设置环境变量 TORCHDYNAMO_REPRO_AFTER="dynamo"

适用于 TorchInductor 的最小化器

如果错误已知发生在 AOTAutograd 之后,则查找在 TorchInductor 降低(lowering)过程中可重现错误的最小子图

设置环境变量 TORCHDYNAMO_REPRO_AFTER="aot"

Dynamo 精度最小化器

查找在您怀疑问题出在 AOTAutograd 时,可重现 eager 模式模型和优化模型之间精度问题的最短子图

TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=4

Inductor 精度最小化器

查找在您怀疑问题出在后端(例如 Inductor)时,可重现 eager 模式模型和优化模型之间精度问题的最短子图。如果此方法无效,请尝试使用 Dynamo 精度最小化器。

TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4

torch._dynamo.explain

查找图中断(graph breaks)并显示其原因

torch._dynamo.explain(fn)(*inputs)

录制/重放

录制和重放帧,以重现图捕获期间的错误

torch._dynamo.config.replay_record_enabled = True

TorchDynamo 函数名过滤

仅编译具有给定名称的函数,以减少调试问题时的噪音

设置环境变量 TORCHDYNAMO_DEBUG_FUNCTION=<name>

TorchInductor 调试日志

打印通用的 TorchInductor 调试信息以及生成的 Triton/C++ 代码

torch._inductor.config.debug = True

TorchInductor 跟踪

显示每个 TorchInductor 阶段花费的时间 + 输出代码和图可视化

设置环境变量 TORCH_COMPILE_DEBUG=1 或 torch._inductor.config.trace.enabled = True

除了信息和调试日志外,您还可以使用 torch._logging 进行更细粒度的日志记录。

诊断运行时错误#

从宏观上看,TorchDynamo 栈由 Python 代码的图捕获(TorchDynamo)和一个后端编译器组成。例如,后端编译器可能由反向图跟踪(AOTAutograd)和图降低(TorchInductor)组成*。错误可能发生在栈的任何组件中,并会提供完整的堆栈跟踪。

要确定错误发生在哪个组件,您可以使用信息级别的日志记录 torch._logging.set_logs(dynamo = logging.INFO)TORCH_LOGS="dynamo",并查找 Step #: ... 输出。日志记录在每个步骤的开始和结束时生成,因此一个错误应对应的步骤是最近记录的、但尚未记录其结束的那个步骤。这些步骤对应于栈的以下部分:

步骤

组件

1

TorchDynamo

2

编译器后端

3

TorchInductor

如果信息日志不足,您可以使用可用的后端选项。这些选项包括:

  • "eager":仅执行 TorchDynamo 前向图捕获,然后使用 PyTorch 运行捕获的图。这表明 TorchDynamo 是否引发了错误。

  • "aot_eager":执行 TorchDynamo 捕获前向图,然后执行 AOTAutograd 跟踪反向图,而不执行任何额外的后端编译器步骤。然后使用 PyTorch eager 运行前向和反向图。这有助于将问题缩小到 AOTAutograd。

缩小问题的通用步骤如下:

  1. 使用 "eager" 后端运行您的程序。如果错误不再发生,则问题出在使用的后端编译器中(如果使用 TorchInductor,请继续步骤 2。如果不是,请参阅 最小化后端编译器错误)。如果使用 "eager" 后端时错误仍然发生,则问题是 Torchdynamo 错误

  2. 此步骤仅在使用 TorchInductor 作为后端编译器时才需要。使用 "aot_eager" 后端运行模型。如果此后端引发错误,则错误发生在 AOTAutograd 跟踪过程中。如果使用此后端时错误不再发生,则问题是 最小化 TorchInductor 错误

这些情况中的每一种都在接下来的部分中进行了分析。

注意

TorchInductor 后端包括 AOTAutograd 跟踪和 TorchInductor 编译器本身。我们将通过将 TorchInductor 称为后端,将 TorchInductor 降低称为 AOTAutograd 跟踪的图降低的阶段来区分它们。

Torchdynamo 错误#

如果生成的错误在使用 "eager" 后端时发生,则 TorchDynamo 极有可能是错误的来源。下面是一个会产生错误的示例代码。

import torch

import torch._dynamo as dynamo


def test_assertion_error():
    y = torch.ones(200, 200)
    z = {y: 5}
    return z

compiled_test_assertion_error = torch.compile(test_assertion_error, backend="eager")

compiled_test_assertion_error()

上面的代码会产生以下错误:

torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26
due to:
Traceback (most recent call last):
  File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP
    assert isinstance(k, ConstantVariable) or (
AssertionError

from user code:
   File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error
    z = {y: 5}

Set torch._dynamo.config.verbose=True for more information
==========

如消息所示,您可以设置 torch._dynamo.config.verbose=True 来获取完整的堆栈跟踪,以了解 TorchDynamo 和用户代码中的错误。除了此标志,您还可以通过 torch._logging.set_logs(dynamo = logging.INFO)TORCH_LOGS="dynamo" 来设置 TorchDynamo 的 log_level。这些级别包括:

  • logging.DEBUGTORCH_LOGS="+dynamo":打印遇到的每个指令,以及下面列出的所有日志级别。

  • logging.INFO:打印每个被编译的函数(原始和修改后的字节码)以及捕获的图,以及下面列出的所有日志级别。

  • logging.WARNING (默认):打印图中断,以及下面列出的所有日志级别。

  • logging.ERROR:仅打印错误。

如果模型非常大,日志可能会变得难以处理。如果错误发生在模型 Python 代码的深层,那么执行仅发生错误帧以方便调试可能会很有用。有两种可用工具可以实现这一点:

  • 设置环境变量 TORCHDYNAMO_DEBUG_FUNCTION 为所需的函数名,将仅对具有该名称的函数运行 torchdynamo。

  • 启用录制/重放工具(设置 torch._dynamo.config.replay_record_enabled = True),该工具会在遇到错误时转储执行记录。然后可以重放此记录以运行发生错误时仅有的帧。

诊断 TorchInductor 错误#

如果错误在使用 "eager" 后端时未发生,则后端编译器是错误的来源(示例错误)。TorchDynamo 有 不同的后端编译器选择,TorchInductor 满足大多数用户的需求。本节以 TorchInductor 为例,但有些工具也可用于其他后端编译器。

以下是我们正在关注的栈的一部分:

使用 TorchInductor 作为选定的后端时,AOTAutograd 用于从 torchdynamo 捕获的前向图生成反向图。需要注意的是,错误可能发生在跟踪过程中,也可能发生在 TorchInductor 将前向和反向图降低(lower)为 GPU 代码或 C++ 代码时。一个模型通常包含数百甚至数千个 FX 节点,因此缩小导致此问题的确切节点可能非常困难。幸运的是,有一些工具可以自动最小化这些输入图,以聚焦到导致问题的节点。第一步是确定错误是发生在 AOTAutograd 的反向图跟踪过程中,还是发生在 TorchInductor 降低过程中。如上面第 2 步所述,"aot_eager" 后端可以用于单独运行 AOTAutograd,而不进行降低。如果使用此后端时错误仍然发生,这表明错误发生在 AOTAutograd 跟踪过程中。

这是一个例子:

import torch

import torch._dynamo as dynamo

model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])

def test_backend_error():

    y = torch.ones(200, 200)
    x = torch.ones(200, 200)
    z = x + y
    a = torch.ops.aten._foobar(z)  # dummy function which errors
    return model(a)


compiled_test_backend_error = torch.compile(test_backend_error, backend="inductor")
compiled_test_backend_error()

运行此代码应会产生此错误,并在其下方显示更长的堆栈跟踪:

Traceback (most recent call last):
  File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function
    return lowerings[target](*args, **kwargs)
  File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar
    assert False
AssertionError
...

带完整堆栈跟踪的错误

如果您然后将 torch.compile(backend="inductor") 更改为 torch.compile(backend="aot_eager"),它将无错误地运行,因为 问题 出在 TorchInductor 降低过程中,而不是 AOTAutograd 中。

最小化 TorchInductor 错误#

从这里开始,我们将运行最小化器以获得最小的可重现示例。设置环境变量 TORCHDYNAMO_REPRO_AFTER="aot"(或直接设置 torch._dynamo.config.repro_after="aot")将生成一个 Python 程序,该程序将 AOTAutograd 生成的图缩减到可重现错误的最小子图。(下面是一个我们最小化 TorchDynamo 生成的图的示例)使用此环境变量运行程序应该会显示几乎 相同的输出,并附带一行指示 minifier_launcher.py 已写入的位置。输出目录可通过将 torch._dynamo.config.base_dir 设置为有效目录名来配置。最后一步是运行最小化器并检查其是否成功运行。成功运行的外观如下: 。如果最小化器成功运行,它将生成可运行的 Python 代码,该代码可重现确切的错误。对于我们的示例,这是以下代码:

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.13.0a0+gitfddfc44
# torch cuda version: 11.6
# torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5


# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, add):
        _foobar = torch.ops.aten._foobar.default(add);  add = None
        return (_foobar,)

args = [((200, 200), (200, 1), torch.float32, 'cpu')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)
from torch._inductor.compile_fx import compile_fx_inner

compiled = compile_fx_inner(mod, args)
compiled(*args)

名为 Repro 模块的 forward 方法包含导致问题的确切操作。在提交 issue 时,请包含任何最小化的重现示例以帮助调试。

最小化后端编译器错误#

对于 TorchInductor 以外的后端编译器,查找导致错误的子图的过程与 最小化 TorchInductor 错误 中的过程几乎相同,但有一个重要的注意事项。即,最小化器将在 TorchDynamo 跟踪的图上运行,而不是在 AOTAutograd 的输出图上运行。让我们通过一个示例来演示。

import torch

import torch._dynamo as dynamo

model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])
# toy compiler which fails if graph contains relu
def toy_compiler(gm: torch.fx.GraphModule, _):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            assert False

    return gm


def test_backend_error():
    y = torch.ones(200, 200)
    x = torch.ones(200, 200)
    z = x + y
    a = torch.relu(z)
    return model(a)


compiled_test_backend_error = torch.compile(test_backend_error, backend=toy_compiler)
compiled_test_backend_error()

为了在 TorchDynamo 跟踪了前向图之后运行代码,您可以使用 TORCHDYNAMO_REPRO_AFTER 环境变量。使用 TORCHDYNAMO_REPRO_AFTER="dynamo"(或 torch._dynamo.config.repro_after="dynamo")运行此程序应生成 此输出,并在 {torch._dynamo.config.base_dir}/repro.py 中生成以下代码。

注意

TORCHDYNAMO_REPRO_AFTER 的另一个选项是 "aot",它将在生成反向图后运行最小化器。

import torch
import torch._dynamo as dynamo
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, add):
        relu = torch.relu(add);  add = None
        return (relu,)


mod = Repro().cuda()
opt_mod = torch.compile(mod, backend="None")


args = [((200, 200), (200, 1), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)

最小化器已成功将图缩减到导致 toy_compiler 中错误的那个操作。与 最小化 TorchInductor 错误 中的过程的另一个区别是,在遇到后端编译器错误后,最小化器会自动运行。成功运行后,最小化器会将 repro.py 写入 torch._dynamo.config.base_dir

性能剖析#

访问 TorchDynamo 性能剖析器#

TorchDynamo 具有内置的统计函数,用于收集和显示每个编译阶段花费的时间。这些统计信息可以通过在执行 Torch._Dynamo 后调用 torch._dynamo.utils.compile_times() 来访问。默认情况下,它会返回按名称表示的每个 TorchDynamo 函数的编译时间字符串。

使用 TORCH_COMPILE_DEBUG 进行 TorchInductor 调试#

TorchInductor 具有内置的统计和跟踪函数,用于显示每个编译阶段花费的时间、输出代码、输出图可视化和 IR 转储。这是一个旨在简化 TorchInductor 内部的理解和故障排除的调试工具。

让我们使用以下测试程序(repro.py)来运行一个示例:

import torch

@torch.compile()
def test_model(x):
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.LayerNorm(10),
        torch.nn.ReLU(),
    )
    return model(x)


y = test_model(torch.ones(10, 10))

设置环境变量 TORCH_COMPILE_DEBUG=1 将创建一个调试跟踪目录。默认情况下,该目录将在当前目录中,名为 torch_compile_debug(这可以通过 torchdynamo 的配置字段 debug_dir_root 和环境变量 TORCH_COMPILE_DEBUG_DIR 来覆盖)。在此目录中,每次运行都会有一个名为时间戳和进程 ID 的独立文件夹。

$ env TORCH_COMPILE_DEBUG=1 python repro.py
$ cd torch_compile_debug
$ ls
run_2023_03_01_08_20_52_143510-pid_180167

在运行文件夹中,会有一个 torchdynamo 目录,其中包含调试日志,以及一个 torchinductor 文件夹,其中包含每个编译内核的子文件夹,以及 Inductor 调试的伪像。

$ cd
run_2023_03_01_08_20_52_143510-pid_180167
$ ls
torchinductor  torchdynamo

进一步进入 torchinductor 目录,\*.log 文件是 AOT Autograd 编译阶段的日志,model__0_forward_1.0 包含 Inductor 调试的伪像。

$ cd torchinductor
$ ls
aot_model___0_debug.log  model__0_forward_1.0
$ cd model__0_forward_1.0
$ ls
debug.log  fx_graph_readable.py  fx_graph_runnable.py  fx_graph_transformed.py  ir_post_fusion.txt  ir_pre_fusion.txt  output_code.py

以下是内容摘要:

  • fx_graph_readable.pyfx_graph_runnable.py 是 inductor 接收的 fx_graph 的可读和可运行版本。

  • fx_graph_transformed.py 是 inductor 运行所有 fx 传递后的 fx 图。

  • ir\*.txt 是 fusion 前后的 inductor ir。

  • output_code.py 是子图的已编译 triton 内核。

这是测试程序的 示例调试目录内容

import torch

@torch.compile()
def test_model(x):
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.LayerNorm(10),
        torch.nn.ReLU(),
    )
    return model(x)


y = test_model(torch.ones(10, 10))

此新调试格式中的每个文件都可以通过 torch._inductor.config.trace.* 进行启用和禁用。由于生成性能剖析图和图示的成本很高,因此它们默认是禁用的。

此新调试格式中的一个节点如下所示:

buf1: SchedulerNode(ComputedBuffer)
buf1.writes =
    {   MemoryDep(name='buf1', index=0, size=()),
        MemoryDep(name='buf1', index=0, size=(s0,))}
buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf1.group.device = cuda:0
buf1.group.iteration = (1, s0)
buf1.sizes = ([], [s0])
class buf1_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
        return reduction

有关更多示例,请参阅 示例调试目录输出

图中断#

给定如下程序:

def some_fun(x):
    ...

compiled_fun = torch.compile(some_fun, ...)
...

TorchDynamo 将尝试将 some_fun 内的所有 torch/tensor 操作编译成单个 FX 图,但它可能无法将所有内容捕获到一个图中。

有些图中断的原因对于 TorchDynamo 来说是无法克服的,并且不容易修复。- 调用 C 扩展(除 torch 之外)对 torchdynamo 是不可见的,并且可以执行任意操作,而 TorchDynamo 无法引入必要的 guard(请参阅 使 Dynamo Sound:Guards)来确保编译后的程序可以安全重用。图中断会阻碍性能,如果产生的碎片很小。为了最大化性能,尽可能少地出现图中断非常重要。

识别图中断的原因#

为了识别程序中的所有图中断及其原因,可以使用 torch._dynamo.explain。此工具会在提供的函数上运行 TorchDynamo,并汇总遇到的图中断。以下是示例用法:

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation_verbose)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

输出包括:

  • out_guards - 一个列表的列表,其中每个子列表包含必须通过的 guard,以确保跟踪的图有效。

  • graphs - 一个已成功跟踪的图模块列表。

  • ops_per_graph - 一个列表的列表,其中每个子列表包含图中运行的操作。

要在遇到的第一个图中断处抛出错误,请使用 fullgraph 模式。此模式禁用 TorchDynamo 的 Python 回退,只有当整个程序都可以转换为单个图时才会成功。示例用法:

def toy_example(a, b):
   ...

compiled_toy = torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

过度的重新编译#

当 TorchDynamo 编译一个函数(或其一部分)时,它会做出关于局部变量和全局变量的某些假设,以便进行编译器优化,并用 guard 来表达这些假设,这些 guard 在运行时会检查特定值。如果任何一个 guard 失败,Dynamo 将重新编译该函数(或其一部分)最多 torch._dynamo.config.recompile_limit 次。如果您的程序达到了缓存限制,您首先需要确定是哪个 guard 失败以及您的程序中哪一部分触发了它。

如果您的程序表现出有限的动态性,您可能可以调整 TorchDynamo 缓存限制,以允许编译和缓存每种变化,但如果缓存限制过高,您可能会发现重新编译的成本超过了任何优化收益。

torch._dynamo.config.recompile_limit = <your desired cache limit>

TorchDynamo 计划支持许多常见的动态张量形状,例如变化的批量大小或序列长度。它不计划支持秩动态性。在此期间,可以设置特定的缓存限制并结合分桶技术,以实现可接受的重新编译次数来处理一些动态模型。

精度调试#

如果您设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4,精度问题也可以被最小化。它以类似 git bisect 的模型运行,并且完整的重现可能类似 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4。我们需要这样做是因为下游编译器将生成代码,无论是 Triton 代码还是 C++ 后端,这些下游编译器的数值可能在细微之处有所不同,但却可能对您的训练稳定性产生巨大影响。因此,精度调试器对于我们检测我们代码生成或后端编译器中的错误非常有用。

如果您想确保随机数生成在 torch 和 triton 之间保持一致,您可以启用 torch._inductor.config.fallback_random = True

扩展调试#

可以通过使用以下实验性标志启用扩展调试。

TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED - 如果 guard 的字符串表示与此标志值匹配,则提供扩展的调试信息。例如,将其设置为“Ne(s0, 10)”将在发出 guard 时生成完整的 Python 和 C++ 回溯。 TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL - 当分配某个特定符号时,提供扩展的调试信息。例如,将其设置为“u2”将在创建此符号时生成完整的 Python 和 C++ 回溯。 TORCHDYNAMO_EXTENDED_DEBUG_CPP - 为所有扩展调试设置以及错误提供扩展的调试信息(C++ 回溯)。例如,将其设置为“1”。C++ 回溯速度很慢且非常冗长,因此默认情况下不包含在扩展调试中。

冷启动时间测量和缓存损坏调试#

为了测量冷启动编译时间或调试缓存损坏,可以传递 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 或设置 torch.compiler.config.force_disable_caches = True,这将覆盖任何其他缓存配置选项并禁用所有编译时间缓存。