评价此页

torch.compile 故障排除#

创建于:2022 年 11 月 28 日 | 最后更新于:2025 年 8 月 14 日

您正尝试在 PyTorch 模型上使用 torch.compile 来提高其性能,但效果不如预期。也许性能没有提升,出现了崩溃,或者编译时间过长。本文提供了技巧、解决方法和调试工具,以帮助您克服这些挑战。

内容

设定预期#

torch.compile 被设计为通用 PyTorch 编译器。与之前的编译器解决方案 TorchScript 不同,torch.compile 需要更少的代码修改,这意味着模型通常不需要从头开始重写。它还能更平稳地处理不支持的代码——不支持的代码会导致优化机会的丢失,而不是崩溃。

理想情况下,人们可以直接将 torch.compile 应用于任何 PyTorch 模型并享受自动加速。然而,在现实中,代码的复杂性可能导致以下三种情况之一:

  1. torch.compile 无缝工作,提供加速。

  2. 需要一些代码修改。 torch.compile 不会崩溃或花费太多时间,但您可能看不到显著的性能提升。

  3. 需要对代码进行大量更改。

我们预计大多数代码将属于情况 (1) 和 (2)。本文档提供了按参与度级别排列的技巧,以帮助解决情况 (2) 中的代码问题。

编译时间#

torch.compile 作为即时编译器运行,因此首次运行或前两次运行已编译函数的速度明显变慢是正常的。在特定条件下(详见下文)可能发生的重新编译也会使运行速度变慢。各种 torch.compile 组件会缓存结果,以减少未来调用的编译时间,即使在不同的进程中也是如此。冷启动(未缓存)的编译时间通常为几秒到几分钟,对于常见或经过基准测试的模型。较大的模型可能需要 30 分钟到几个小时。

术语#

以下术语与 torch.compile 问题故障排除相关。

图中断#

torch.compile 会跟踪您的代码,并尝试将您的 PyTorch 代码捕获到一个 PyTorch 算子的计算图中(FX 图)。然而,这并非总是可能的。当遇到无法跟踪的代码时,就会发生“图中断”。图中断包括编译到目前为止已确定的 FX 图,运行不支持的代码,然后用新的 FX 图从不支持的代码处恢复跟踪。由于计算图被打破,我们失去了优化机会,因此模型代码应尽可能避免图中断。图中断发生在以下情况:

  • 数据依赖的 if 语句

  • 许多 Python 内置函数

  • C 函数

下面是一个由于 Python 内置库中的 copy.deepcopy 函数引起的图中断示例(确切输出可能有所不同)。

import torch

@torch.compile
def fn(x):
    x = x + 1
    with open("test.txt", "r") as f:
        return x + len(f.read())

fn(torch.ones(3, 3))
$TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
    with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
    return handler(tx, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
    return lambda *args: unimplemented(error_msg)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False

Guard#

torch.compile 在跟踪代码时会做出一些关于运行时值的假设。在跟踪期间,我们会生成“guards”,这些是这些假设的运行时检查。Guard 在编译函数的未来调用中运行,以确定我们是否可以重用先前编译的代码。运行时检查的例子包括常量值、类型和对象 ID。

下面是一个生成的 Guard 的示例。TENSOR_MATCH Guard 检查输入的类型、设备、dtype、形状等。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x + 1  # playground.py:6 in fn

重新编译#

如果之前编译的代码的所有实例的 Guard 都失败了,那么 torch.compile 必须“重新编译”该函数,需要再次跟踪原始代码。

在下面的示例中,需要重新编译,因为检查张量参数形状的 Guard 失败了。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

动态形状#

torch.compile 最初假设张量形状是静态/不变的,并基于这些假设创建 Guard。通过使用“动态形状”,我们可以让 torch.compile 生成可以接受不同形状张量输入的编译代码——我们避免了每次形状不同时都重新编译。默认情况下,自动动态形状是启用的 torch.compile(dynamic=None) ——如果由于形状不匹配导致编译失败,则会尝试使用动态形状进行重新编译。动态形状也可以完全启用 dynamic=True 或禁用 dynamic=False

下面,我们启用了动态形状,并注意到我们不再需要重新编译。

import torch

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards

有关动态形状的更多信息,请参阅 动态形状手册

日志工具#

tlparse / TORCH_TRACE#

tlparse / TORCH_TRACE 是一对工具,它们生成如下所示的编译报告: https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html

收集跟踪非常容易。要收集跟踪,请使用以下命令运行您的重现命令:

TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir

这种方法甚至适用于您正在运行分布式作业的情况,为每个 rank 提供一个跟踪。它将在浏览器中打开 HTML,类似于上面生成的 HTML。如果您正在为没有独立重现的复杂问题编写 bug 报告,您仍然可以通过附加在 /tmp/tracedir 中生成的跟踪日志来极大地帮助 PyTorch 开发者。

警告

跟踪日志包含您的所有模型代码。如果您的模型是敏感的,请不要共享跟踪日志。跟踪日志不包含权重。

tlparse 的输出主要面向 PyTorch 开发者,并且日志格式易于上传并在 GitHub 上共享。然而,作为非 PyTorch 开发者,您仍然可以从中提取有用的信息。我们建议从报告中的内联帮助文本开始,它解释了其内容。以下是一些您可以从 tlparse 中获得的见解:

  • 通过查看堆栈树,了解了哪些模型代码被编译了?如果您不熟悉正在编译的代码库,这一点尤其有用!

  • 有多少图中断 / 不同的编译区域?(每个独立的编译都是自己的颜色编码块,如 [0/0])。可能图中断的帧是浅绿色 [2/4]。如果有很多帧,那很可疑,这表明您发生了一些灾难性的图中断,或者您的代码可能不适合 torch.compile

  • 我重新编译了某个帧多少次?反复重新编译的帧会显示为: [10/0] [10/1] [10/2] - 如果某项被反复重新编译,那非常可疑,值得深入研究,即使它不是您问题的根本原因。

  • 是否存在编译错误?出现错误的帧将显示为 [0/1]

  • 为给定帧生成了哪些中间编译器产品?例如,您可以查看高级生成的 FX 图或生成的 Triton 代码。

  • 某个帧是否有相关信息?您可以在 compilation_metrics 中找到它们。

TORCH_LOGS#

您可以使用 TORCH_LOGS 环境变量选择性地启用 torch.compile 堆栈的部分以进行日志记录。TORCH_LOGS 实际上是 tlparse 的日志来源。TORCH_LOGS 环境变量的格式如下:

TORCH_LOGS="<option1>,<option2>,..." python foo.py

有用的高级选项包括:

  • graph_breaks:记录用户代码中图中断的位置及其原因。

  • guards:记录生成的 Guard。

  • recompiles:记录哪个函数被重新编译以及导致重新编译的失败 Guard。

  • dynamic:记录与动态形状相关的日志。

此外,您还可以使用 torch._logging.set_logs 以编程方式设置日志选项。

import logging
torch._logging.set_logs(graph_breaks=True)
...

TORCH_LOGS 的更多选项请参阅 TORCH_LOGS 选项摘要。有关完整选项列表,请参阅 torch._loggingtorch._logging.set_logs

tlparse 与 TORCH_LOGS#

通常,我们建议在遇到问题时首先使用 tlparsetlparse 非常适合调试大型模型并获得模型编译方式的高级概览。另一方面,TORCH_LOGS 更适合小型示例和精细的调试细节,当您已经大致了解是哪个 torch.compile 组件引起问题时。

简单的解决方法#

在这里,我们描述了一些针对 torch.compile 问题的解决方法,这些问题涉及小的代码修改或更改一些 torch.compile 设置。

在哪里应用 torch.compile?#

我们建议将 torch.compile 应用于最高级别的、不会导致过多问题的函数。通常,这是您的训练或评估步骤(带优化器但没有循环),您的顶层 nn.Module,或一些子 nn.Moduletorch.compile 特别不擅长处理分布式包装模块(如 DDP 或 FSDP),因此请考虑将 torch.compile 应用于传递给包装器的内部模块。

# inference
model = ...
opt_model = torch.compile(model)

for _ in range(N_ITERS):
    inp = ...
    out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

禁用和抑制错误#

对于某些模型架构,模型的一部分可能特别难以编译——要么有很多图中断,要么发生崩溃。您可能希望显式禁用这些有问题的模型部分,以便您可以将 torch.compile 应用于可工作的部件。您可以使用 `@torch.compiler.disable` 装饰器来实现此目的。当 torch.compile 尝试调用已禁用的函数时,它会中断图并跳过禁用函数的跟踪,然后在不支持的函数调用后用新的 FX 图恢复跟踪。默认情况下,从被禁用函数发出的所有递归调用也都被禁用。使用 `recursive=False` 选项可以允许递归调用的编译。

def bad1_inner(...):
    # skipped

@torch.compiler.disable
def bad1_outer(...):
    # skipped
    bad1_inner(...)

def bad2_inner(...)
    # traced

@torch.compiler.disable(recursive=False)
def bad2_outer(...):
    # skipped
    bad2_inner(...)

@torch.compile
def fn(...):
    # graph break
    bad1_outer(...)
        ...
    # graph break
    bad2_outer(...)

例如,我们使用 torch.compiler.disable 来禁用推荐模型中稀疏架构上的 torch.compile,因为稀疏架构很难编译。预处理和日志记录函数是导致大量图中断且从编译中获益不大的函数的其他示例。

如果您遇到编译器崩溃并希望继续进行,可以设置 torch._dynamo.config.suppress_errors = True。当编译器崩溃时,我们将跳过函数的跟踪并稍后重试。这不是最佳实践——最好最终手动添加所需的禁用注解。

解决图中断#

为了最大化优化机会,减少图中断的数量很重要。请记住,您可以使用 tlparseTORCH_LOGS="graph_breaks" 来查看正在发生的图中断。通常,图中断是由以下原因之一引起的:

  1. 您正在尝试执行一项根本无法跟踪的操作,例如数据依赖的控制流。

  2. 您正在尝试执行一项尚未支持的操作。例如,我们目前对跟踪使用内置 Python inspect 模块的代码的支持有限。

  3. 您的代码中有错误。例如,您可能尝试使用错误的参数数量调用函数。

图中断日志会告诉您用户代码的位置和图中断的原因。不幸的是,许多图中断需要对 Dynamo 有更深入的了解才能处理。甚至可能难以确定您的图中断的真正原因属于这三个原因中的哪一个。我们正在努力使图中断消息更具可操作性。

此外,图中断对丢失优化机会的影响也不同。例如,发生在模型 forward 中间的图中断可能比发生在 forward 开始处的预处理部分的图中断产生更负面的影响。因此,阻止每一个中断并非至关重要,而是要阻止那些导致性能显著下降的中断。

如果图中断消息没有建议任何操作,您怀疑图中断的原因是 (2),并且您认为图中断导致性能下降,那么请将该图中断报告为一个问题。如果一个函数有很多图中断,请考虑禁用该函数的编译,因为图中断的开销可能会变得过高。

下面是一些常见的图中断及其解决方法。

数据依赖的操作#

torch.compile 在数据依赖的操作上会发生图中断,例如数据依赖的控制流(if 语句、带有张量的循环)和直接的张量数据访问(.item.data_ptr)。

import torch

@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
    if y > 0:

Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
    return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
    unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

这些图中断的通用解决方法是避免执行数据依赖的操作。一些具体的解决方法是:

  • 如果您的控制流实际上不依赖于数据值,请考虑修改您的代码以对常量执行控制流。

# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )
  • 如果您有 .item() 调用,请尝试 torch._dynamo.config.capture_scalar_outputs = TrueTORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1

  • 将函数中有问题的部分包装在自定义算子中。

自定义算子#

如果您有 torch.compile 难以跟踪的代码,无论是由于缺少支持还是根本不兼容,您都可以考虑将有问题的代码包装在自定义算子中。

自定义算子需要额外的一些工作才能使其与 torch.compile 兼容。有关更多详细信息,请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html

打印#

打印/日志/发出警告将导致图中断。如果您有一个函数进行许多日志记录调用,例如,一个记录有关训练迭代数据的函数,请考虑在该函数上应用 torch.compiler.disable

或者,您可以尝试使用 torch._dynamo.config.reorderable_logging_functions。此配置用于重新排序日志函数,以便它们在跟踪函数的末尾被调用,从而避免图中断。但是,如果发生突变,日志内容可能会有所不同。

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!

代码错误#

您的代码可能不正确,或者遇到了来自 torch.compile 之外的错误。在下面的代码中,我们在 torch.sin 调用中犯了一个拼写错误,提供了一个额外的参数。

import torch

@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
    y = torch.sin(x, x)
...

从日志中很难判断错误是由您的代码引起的,还是由 torch.compile 的 bug 引起的。为了区分,我们建议尝试在没有 torch.compile 的情况下运行您的代码,看看是否仍然出现错误。

处理重新编译#

您可以使用 tlparseTORCH_LOGS=recompiles 查看重新编译及其原因。

是否启用了动态形状?#

由于形状不匹配而导致的重新编译形式为:

tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

确保 torch.compile 的 `dynamic` 选项未设置为 False。默认选项 dynamic=None 只会在第一次编译后尝试动态形状。您可以设置 dynamic=True 以提前尽可能地进行动态编译。

有关动态形状的更多信息,请参阅 动态形状手册

更改缓存大小限制#

函数可以被重新编译的次数是有限制的,由 torch._dynamo.config.recompile_limittorch._dynamo.config.accumulated_recompile_limit 确定。如果任一限制被超过,我们将不再尝试重新编译该函数,而是会以惰性模式运行该函数。 torch.compile 还会发出警告,其中包含受影响的函数以及达到了哪个限制。在下面的示例中,每次函数调用都会导致一次重新编译尝试。当我们达到缓存大小限制(8)时,我们会停止尝试重新编译。

import torch

@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
    last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

如果您知道重新编译次数有一个合理的固定上限,您可以提高缓存大小限制。如果重新编译的成本超过了编译的好处,那么您可以考虑降低缓存大小限制。

用张量包装常量#

默认情况下,int / float 变量被视为常量并以此进行 Guard。在下面的示例中,每次函数调用都会导致一次重新编译。

import torch

@torch.compile
def fn(x, c):
    return x + c

for i in range(1, 10):
    fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/7: L['c'] == 8.5
    - 0/6: L['c'] == 7.5
    - 0/5: L['c'] == 6.5
    - 0/4: L['c'] == 5.5
    - 0/3: L['c'] == 4.5
    - 0/2: L['c'] == 3.5
    - 0/1: L['c'] == 2.5
    - 0/0: L['c'] == 1.5
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
    last reason: 0/0: L['c'] == 1.5

特别是对于 LR 调度器,使用常量进行初始化可能导致重新编译。

import torch

mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def fn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(1, 10):
    fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
    triggered by the following guard failure(s):
    - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
    - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
    - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
    - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
    - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
    - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
    - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
    - 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.recompile_limit (8)
    function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
    last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

在以上两个示例中,我们可以将浮点变量包装在张量中以防止重新编译。

# first example
for i in range(1, 10):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))

报告问题#

如果上述解决方法不足以使 torch.compile 工作,那么您应该考虑将问题报告给 PyTorch。但是,您可以做一些事情来使我们的生活变得更轻松。

消融#

使用 torch.compile 的 `backend=` 选项检查 torch.compile 堆栈的哪个组件导致了问题。特别是,尝试:

  • torch.compile(fn, backend="eager"),它只运行 torch.compile 的图捕获组件 TorchDynamo。

  • torch.compile(fn, backend="aot_eager"),它运行 TorchDynamo 和 AOTAutograd,后者在编译期间还会生成后向图。

  • torch.compile(fn, backend="aot_eager_decomp_partition"),它运行 TorchDynamo 和 AOTAutograd,并进行算子分解/分区。

  • torch.compile(fn, backend="inductor"),它运行 TorchDynamo、AOTAutograd 和 TorchInductor,这是生成编译内核的后端 ML 编译器。

如果您仅在 Inductor 后端出现问题,您还可以测试各种 Inductor 模式:

  • torch.compile(fn, backend="inductor", mode="default")

  • torch.compile(fn, backend="inductor", mode="reduce-overhead")

  • torch.compile(fn, backend="inductor", mode="max-autotune")

您还可以检查动态形状是否导致任何后端出现问题:

  • torch.compile(fn, dynamic=True)(始终使用动态形状)

  • torch.compile(fn, dynamic=False)(从不使用动态形状)

  • torch.compile(fn, dynamic=None)(自动动态形状)

二分查找#

您是否尝试过最新的 nightly 版本?过去有效的东西现在是否不再有效?您能否二分查找以确定问题的首次出现版本?二分查找对于性能、准确性或编译时间回归尤其有用,因为这些问题很难立即确定其来源。

创建重现器#

创建重现器需要大量工作,如果您没有时间完成它,那也是完全可以接受的。然而,如果您是一位熟悉 torch.compile 内部结构但不熟悉的有积极性的用户,创建独立的重现器对我们修复 bug 的能力有巨大的影响。没有重现器,您的 bug 报告必须包含足够的信息,以便我们确定问题的根源并从头开始编写重现器。

以下是可用重现器的列表,按偏好程度从高到低排序:

  1. 独立的、小的重现器: 一个没有外部依赖项的脚本,代码行数少于 100 行,运行后能重现问题。

  2. 独立的、大的重现器: 即使代码很大,独立性也是一个巨大的优势!

  3. 具有可管理依赖项的非独立重现器: 例如,如果您可以通过运行一个脚本并在 `pip install transformers` 后进行重现,那是可以管理的。我们很可能能够运行它并进行调查。

  4. 需要大量设置的非独立重现器: 这可能涉及下载数据集、多个环境设置步骤或需要 Docker 镜像的特定系统库版本。设置越复杂,我们重现环境就越困难。

    注意

    Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary.
    

在某种程度上,一个可以在单个进程中运行的重现器比需要多进程训练的重现器更好(但同样,如果您只有一个多进程重现器,我们会接受!)。

此外,以下是可以在您的问题中检查的方面列表,您可以在重现器中尝试重现这些方面:

  • Autograd。是否有 `requires_grad=True` 的张量输入?是否在输出上调用了 `backward()`?

  • 动态形状。您是否设置了 `dynamic=True`?或者您是否多次使用不同形状运行了测试代码?

  • 自定义算子。在实际工作流程中是否涉及自定义算子?您能否使用 Python 自定义算子 API 重现其某些重要特征?

  • 配置。您是否设置了所有相同的配置?这包括 `torch._dynamo.config` 和 `torch._inductor.config` 设置,以及 `torch.compile` 的参数,如 `backend` / `mode`。

  • 上下文管理器。您是否重现了任何活动的上下文管理器?这可能包括 `torch.no_grad`、自动混合精度、`TorchFunctionMode` / `TorchDispatchMode`、激活检查点、编译的 autograd 等。

  • 张量子类。是否涉及张量子类?

Minifier#

Minifier 是一个早期的 torch.compile 工具,它接受一个在尝试运行或编译时崩溃的 FX 图,找到一个同样崩溃的子图,并输出执行该子图操作的代码。本质上,minifier 找到了某种 torch.compile 相关崩溃的最小重现。这假设我们能够成功跟踪代码。

不幸的是,如今大多数时候,minifier 的效果不佳,可能需要其他方法。这可能是因为以这种方式自动重现的 bug 通常更容易修复并且已经得到解决,剩下更复杂且不易重现的问题。然而,尝试使用 minifier 是很直接的,所以即使它可能不成功,也值得一试。

操作 minifier 的说明可以在此处找到。如果编译器崩溃,您可以设置 TORCHDYNAMO_REPRO_AFTER="dynamo"TORCHDYNAMO_REPRO_AFTER="aot"aot 选项更可能成功,尽管它可能无法识别 AOTAutograd 问题。这将生成 `repro.py` 文件,可能有助于诊断问题。对于准确性相关的问题,请考虑设置 TORCHDYNAMO_REPRO_LEVEL=4。请注意,这可能并不总是能成功识别有问题的子图。

深入调试#

本节提供了独立调试 torch.compile 问题或深入了解 torch.compile 堆栈的工具和技术。这些方法比上面介绍的方法更复杂,并且是 PyTorch 开发者经常用来调试实际 torch.compile 问题的。

下面是堆栈的高级概述:

Torch Dynamo Stack

堆栈由三个主要组件组成:TorchDynamo、AOTAutograd 和 Inductor。我们的调试策略包括首先识别错误发生的组件,然后单独调试该组件。要确定负责该问题的组件,请参阅上面“报告问题”下的“消融”部分。有关调试特定组件的指导,请参阅以下各节。

TorchDynamo#

记录 Dynamo 正在跟踪的内容#

`TORCH_LOGS=trace_bytecode` 选项使您能够查看 Dynamo 正在跟踪的确切字节码指令,以及 Python 解释器堆栈的符号表示。在遇到图中断或崩溃时,建议检查最后几个被跟踪的字节码指令。

您还可以使用 `TORCH_LOGS=trace_source` 来查看 Dynamo 正在跟踪的源代码行。这可以与 `trace_bytecode` 结合使用,以查看每个被跟踪的字节码指令对应的源代码行。

最后,您可以使用 `TORCH_LOGS=graph_code` 来查看表示 Dynamo 跟踪的 FX 图的 Python 代码。您可以查看此代码以仔细检查正在跟踪的正确算子。

import torch

def g(x, y):
    return x + y

@torch.compile(backend="eager")
def f(x):
    x = torch.sin(x)
    x = g(x, x)
    return x

f(torch.ones(3, 3))
$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
    @torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
        x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
        x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
    def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
        return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
        return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

        # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
        x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

        # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
        x_1: "f32[3, 3][3, 1]cpu" = x + x;  x = None
        return (x_1,)

设置 Dynamo 跟踪的断点#

有时在 Dynamo/用户代码中插入断点有助于查看 Dynamo 在跟踪用户代码时的状态。不幸的是,以常规 Python 方式插入断点会导致 TorchDynamo 中的图中断,因此我们无法在计划设置断点的地方查看 Dynamo 的状态。

设置断点的第一种方法是在 Dynamo 源代码中插入。三个推荐的断点位置是:

  • torch/_dynamo/symbolic_convert.py 中,在命名与有问题的字节码指令相同的函数处设置断点,例如 `def CALL_FUNCTION` 和 `def STORE_ATTR`。您可以根据输入有条件地设置断点,例如,指令的 `argval`,或者位于堆栈顶部的对象的名称,因为某些字节码操作码经常被使用。

  • 在图中断或错误起源处设置断点。通常,图中断是从对 `unimplemented(...)` 的调用发出的。

  • torch/_dynamo/variables/builder.py, function:_wrap 中设置断点。您很可能需要根据输入有条件地设置断点。此函数决定如何符号化地表示给定值。如果您怀疑某个值被错误地表示,请考虑在此处设置断点。

插入断点的第二种方法是使用 torch._dynamo.comptime.comptime.breakpoint

from torch._dynamo.comptime import comptime

@torch.compile
def f(...):
    ...
    comptime.breakpoint()
    ...

comptime 断点很方便,因为它允许您在正在跟踪的用户代码的特定位置检查 Dynamo 状态。它不需要您在 Dynamo 源代码中插入断点或根据变量有条件地设置断点。

当触发 comptime 断点时,您可以执行以下操作:

  • ctx.print_bt() 打印用户堆栈跟踪。

  • ctx.print_locals() 打印所有当前局部变量。

  • ctx.print_graph() 打印当前跟踪的图。

  • ctx.disas() 打印当前跟踪函数的字节码。

  • 使用标准的 `pdb` 命令,例如 `bt/u/d/n/s/r` - 您可以向上遍历 `pdb` 堆栈以检查更多 Dynamo 内部。

import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x + 1
    comptime.breakpoint()
    y = y + 1
    return y

f(torch.ones(3, 3))
$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
    comptime.breakpoint()

(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()



def forward(self, L_x_: "f32[3, 3]"):
    l_x_ = L_x_

    # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
    y: "f32[3, 3]" = l_x_ + 1;  l_x_ = y = None

字节码生成错误#

虽然不常见,但 Dynamo 可能会生成错误的字节码。如果您确定以下几点,这可能会发生:

  • 消融表明错误发生在 TorchDynamo 级别。

  • 错误不是从 TorchDynamo 堆栈帧发出的。

  • 错误看起来更像是用户错误而不是 Dynamo 错误,或者是一个段错误。

  • 没有 torch.compile 时错误不会发生。

字节码生成 bug 通常很难修复,我们建议提交一个问题而不是自己尝试修复。如果您有兴趣查看 Dynamo 生成的字节码,可以使用 TORCH_LOGS=bytecode。您可以在此处查看 Dynamo 生成的字节码的高级概述。

AOTAutograd#

AOTAutograd 错误通常难以调试 - 我们建议直接提交一个问题。AOTAutograd 的日志输出主要有助于查看 Inductor 的输入是什么。

TORCH_LOGS 选项摘要#

有用的 TORCH_LOGS 选项摘要如下:

选项

描述

+all

输出所有 torch.compile 组件的调试日志。

+dynamo

输出 TorchDynamo 的调试日志。

+aot

输出 AOTAutograd 的调试日志。

+inductor

输出 TorchInductor 的调试日志。

dynamic

输出动态形状相关的日志。

graph_code

输出 Dynamo 生成的 FX 图的 Python 代码。

graph_sizes

输出 Dynamo 生成的 FX 图的张量大小。

trace_bytecode

输出 Dynamo 正在跟踪的字节码指令以及 Dynamo 正在维护的符号解释器堆栈。

trace_source

输出 Dynamo 当前正在跟踪的原始源代码行。

bytecode

输出 Dynamo 生成的字节码。

guards

输出生成的 Guard。

recompiles

输出重新编译原因(仅第一个失败的 Guard 检查)。

recompiles_verbose

输出重新编译时所有失败的 Guard 检查。

aot_graphs

输出 AOTAutograd 生成的图。

aot_joint_graphs

输出 AOTAutograd 生成的前向-后向联合图。

output_code

输出 Inductor 生成的代码。

kernel_code

输出 Inductor 按内核生成的代码。

schedule

输出 Inductor 调度日志。

perf_hints

输出 Inductor 性能提示日志。

fusion

输出 Inductor fusion 日志。

有关完整选项列表,请参阅 torch._loggingtorch._logging.set_logs