评价此页

torch.compile 故障排除#

创建于:2022年11月28日 | 最后更新于:2025年6月10日

您尝试在 PyTorch 模型上使用 torch.compile 来提高其性能,但效果不如预期。可能是性能没有提升,出现崩溃,或者编译时间太长。本文提供了提示、解决方法和调试工具,以帮助您克服这些挑战。

内容

设定预期#

torch.compile 被设计为一款通用的 PyTorch 编译器。与之前的编译器解决方案 TorchScript 不同,torch.compile 需要更少的代码更改,这意味着模型通常不需要从头开始重写。它还能更优雅地处理不支持的代码 - 不支持的代码会导致优化机会的丢失,而不是崩溃。

理想情况下,您可以直接将 torch.compile 应用于任何 PyTorch 模型,并享受自动加速。然而,实际上,代码的复杂性可能导致以下三种情况之一:

  1. torch.compile 无缝工作,提供加速。

  2. 需要一些代码修改。torch.compile 不会崩溃或花费过多时间,但您可能看不到显著的性能提升。

  3. 需要进行大量代码更改。

我们预计大多数代码将属于情况 (1) 和 (2)。本文档提供了按参与级别排列的提示,以帮助解决情况 (2) 中的代码问题。

编译时间#

torch.compile 作为即时编译器运行,因此首次运行或前几次运行编译后的函数时,性能会明显变慢。重新编译(在某些条件下会发生,详见下文)也会使运行变慢。各种 torch.compile 组件会缓存结果,以减少未来调用的编译时间,即使是在不同的进程中。冷启动(未缓存)编译时间通常需要几秒到几分钟(对于常见或已基准测试过的模型)。更大的模型可能需要 30 分钟到几个小时。

术语#

以下术语与解决 torch.compile 问题有关。

图中断 (Graph break)#

torch.compile 会追踪您的代码,并尝试将您的 PyTorch 代码捕获到一个单一的 PyTorch 操作计算图中(FX 图)。然而,这并非总是可能的。当遇到无法追踪的代码时,会发生“图中断”。图中断包括编译到目前为止已确定的 FX 图,运行不支持的代码,然后在不支持的代码之后使用新的 FX 图恢复追踪。由于计算图被分割,我们失去了优化机会,因此模型代码应尽可能避免图中断。图中断会发生在诸如以下情况:

  • 数据依赖的 if 语句

  • 许多 Python 内置函数

  • C 函数

下面是一个由于 Python 内置库中的 copy.deepcopy 函数导致的图中断示例(确切输出可能有所不同)。

import torch

@torch.compile
def fn(x):
    x = x + 1
    with open("test.txt", "r") as f:
        return x + len(f.read())

fn(torch.ones(3, 3))
$TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
    with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
    return handler(tx, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
    return lambda *args: unimplemented(error_msg)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False

Guard(守卫)#

torch.compile 在追踪代码时会做一些关于运行时值的假设。在追踪过程中,我们会生成“guards”(守卫),它们是对这些假设的运行时检查。Guards 会在未来调用编译后的函数时运行,以确定是否可以重用先前编译的代码。运行时检查的例子包括常量值、类型和对象 ID。

下面是一个生成的 guards 示例。TENSOR_MATCH guard 检查输入的类型、设备、dtype、形状等。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x + 1  # playground.py:6 in fn

重新编译#

如果对于先前编译代码的所有实例,guards 都失败了,那么 torch.compile 就必须“重新编译”该函数,需要再次追踪原始代码。

在下面的示例中,由于检查张量参数形状的 guard 失败,因此需要重新编译。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

动态形状#

torch.compile 最初假设张量形状是静态/恒定的,并基于这些假设进行 guarded。通过使用“动态形状”,我们可以让 torch.compile 生成可以接受不同形状张量输入的编译代码——我们避免了每次形状不同时都重新编译。默认情况下,自动动态形状是启用的 torch.compile(dynamic=None) - 如果因形状不匹配而编译失败,则会尝试使用动态形状进行重新编译。动态形状也可以完全启用 dynamic=True 或禁用 dynamic=False

下面,我们启用了动态形状,并注意到我们不再需要重新编译。

import torch

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards

有关动态形状的更多信息,请参阅 动态形状手册

日志工具#

tlparse / TORCH_TRACE#

tlparse / TORCH_TRACE 是一组工具,它们生成的编译报告如下所示: https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html

收集 trace 非常容易。要收集 trace,请使用以下命令运行您的复现命令

TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir

此方法即使在您运行分布式作业时也有效,可为每个 rank 提供一个 trace。它将打开您的浏览器,显示类似于上面生成的 HTML。如果您要报告一个复杂问题的 bug,而您没有独立的复现方法,您仍然可以通过附加在 /tmp/tracedir 中生成的 trace 日志来极大地帮助 PyTorch 开发人员。

警告

trace 日志包含您的所有模型代码。如果您的模型是敏感的,请不要共享 trace 日志。trace 日志不包含权重。

tlparse 的输出主要面向 PyTorch 开发人员,其日志格式易于上传和在 GitHub 上共享。但是,作为非 PyTorch 开发人员,您仍然可以从中提取有用的信息。我们建议从报告中的内联帮助文本开始,它解释了报告的内容。以下是一些您可以从 tlparse 中获得的见解

  • 通过查看堆栈树,可以了解编译了哪些模型代码?如果您不熟悉正在编译的代码库,这尤其有用!

  • 有多少个图中断 / 不同的编译区域?(每个不同的编译都是自己有颜色编码的块,例如 [0/0])。潜在的图中断帧是浅绿色的 [2/4]。如果有很多帧,那很可疑,这表明您遇到了一些灾难性的图中断,或者您的代码与 torch.compile 的匹配度不高。

  • 我重新编译了某个帧多少次?重新编译了很多次的东西看起来会像:[10/0] [10/1] [10/2] — 如果某个东西被重新编译了很多次,那非常可疑,值得深入研究,即使它不是您问题的根本原因。

  • 是否存在编译错误?发生错误的帧看起来会像 [0/1]

  • 对于给定的帧,我生成了哪些中间编译器产品?例如,您可以查看生成的高级 FX 图或生成的 Triton 代码。

  • 对于某个特定的帧是否有相关信息?您可以在 compilation_metrics 中找到它们。

TORCH_LOGS#

您可以使用 TORCH_LOGS 环境变量选择性地启用 torch.compile 堆栈的某些部分进行日志记录。TORCH_LOGS 实际上是 tlparse 日志的来源。TORCH_LOGS 环境变量的格式如下所示

TORCH_LOGS="<option1>,<option2>,..." python foo.py

有用的高级选项包括

  • graph_breaks:记录用户代码中图中断的位置以及图中断的原因

  • guards:记录生成的 guard

  • recompiles:记录哪个函数被重新编译以及导致重新编译的 guard 失败

  • dynamic:记录与动态形状相关的日志

此外,您还可以使用 torch._logging.set_logs 以编程方式设置日志选项

import logging
torch._logging.set_logs(graph_breaks=True)
...

更多 TORCH_LOGS 选项请参见 TORCH_LOGS 选项摘要。有关所有选项的完整列表,请参见 torch._loggingtorch._logging.set_logs

tlparse 与 TORCH_LOGS 的区别#

通常,我们建议在遇到问题时首先使用 tlparsetlparse 非常适合调试大型模型并对模型的编译方式获得高层次的概览。另一方面,对于小型示例和精细的调试细节,当我们已经了解 torch.compile 的哪个组件导致问题时,首选 TORCH_LOGS

简单的解决方法#

在此,我们描述了一些解决 torch.compile 问题的解决方法,这些方法涉及小的代码修改或更改一些 torch.compile 设置。

何处应用 torch.compile?#

我们建议将 torch.compile 应用于最高层级的函数,该函数不会导致过多问题。通常,这是您的 train 或 eval 步骤,包含优化器但不包含循环,您的顶层 nn.Module,或某些子 nn.Moduletorch.compile 对分布式包装模块(如 DDP 或 FSDP)的处理效果不佳,因此请考虑将 torch.compile 应用于传递给包装器的内部模块。

# inference
model = ...
opt_model = torch.compile(model)

for _ in range(N_ITERS):
    inp = ...
    out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

禁用和抑制错误#

对于某些模型架构,模型中存在一些特别难以编译的部分——要么是图中断很多,要么是崩溃。您可能希望显式禁用这些有问题的模型部分,以便将 torch.compile 应用于可以正常工作的那些部分。您可以通过使用 @torch.compiler.disable 装饰器来实现此目的。当 torch.compile 尝试调用已禁用的函数时,它会中断图并跳过禁用函数的跟踪,在调用后恢复跟踪。默认情况下,从已禁用函数进行的所有递归调用也会被禁用。使用 recursive=False 选项可为递归调用启用编译。

def bad1_inner(...):
    # skipped

@torch.compiler.disable
def bad1_outer(...):
    # skipped
    bad1_inner(...)

def bad2_inner(...)
    # traced

@torch.compiler.disable(recursive=False)
def bad2_outer(...):
    # skipped
    bad2_inner(...)

@torch.compile
def fn(...):
    # graph break
    bad1_outer(...)
        ...
    # graph break
    bad2_outer(...)

例如,我们使用 torch.compiler.disable 来禁用推荐模型中稀疏架构上的 torch.compile,因为稀疏架构难以编译。预处理和日志记录函数是其他导致大量图中断且不从编译中获益的函数示例。

如果您遇到编译器崩溃,并且希望继续操作,您可以设置 torch._dynamo.config.suppress_errors = True。当编译器崩溃时,我们将跳过跟踪该函数并稍后重试。这不是最佳实践——最好最终根据需要手动添加禁用注解。

解决图中断#

为了最大化优化机会,减少图中断的数量很重要。回想一下,您可以使用 tlparseTORCH_LOGS="graph_breaks" 查看正在发生的图中断。通常,图中断是由以下原因之一引起的

  1. 您试图做一些根本无法跟踪的事情,例如数据依赖的控制流。

  2. 您试图做一些尚不支持的事情。例如,我们目前对跟踪使用内置 Python inspect 模块的代码的支持有限。

  3. 您的代码有错误。例如,您可能尝试使用错误的参数数量调用函数。

图中断日志会告诉您用户代码的位置和图中断的原因。不幸的是,许多图中断如果没有对 Dynamo 的更深入了解,是无法采取行动的。甚至很难确定这三个原因中哪一个是您图中断的真正原因。我们正在努力使图中断消息更具可操作性。

此外,丢失优化机会的影响在图中断之间是不同的。例如,发生在模型 forward 中间的图中断可能比发生在 forward 开始处的预处理部分中的图中断产生更负面的影响。因此,防止每一个中断并不重要,而是要防止那些导致显著性能下降的中断。

如果图中断消息没有提出任何操作建议,您怀疑您的图中断的原因是 (2),并且您认为该图中断导致了性能下降,那么请将该图中断作为问题报告。如果一个函数有很多图中断,请考虑禁用该函数的编译,因为图中断的开销成本可能会变得过高。

以下是一些常见的图中断及其解决方法。

数据依赖的操作#

torch.compile 在数据依赖的操作上会中断图,例如数据依赖的控制流(if 语句、带有张量的循环)和直接张量数据访问(.item.data_ptr)。

import torch

@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
    if y > 0:

Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
    return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
    unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

这些图中断的通用解决方法是避免执行数据依赖的操作。一些具体的解决方法是

  • 如果您的控制流实际上不依赖于数据值,请考虑修改您的代码以在常量上执行控制流。

# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )
  • 如果您有 .item() 调用,请尝试 torch._dynamo.config.capture_scalar_outputs = TrueTORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1

  • 将函数中有问题的部分包装在自定义 op 中

自定义 op#

如果您有 torch.compile 难以跟踪的代码,无论是由于缺少支持还是根本不兼容,您都可以考虑将有问题的代码包装在自定义 op 中。

自定义 op 需要一些额外的工作才能使其与 torch.compile 兼容。有关更多详细信息,请参见 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html

打印#

打印/日志记录/发出警告将导致图中断。如果您有一个函数进行大量日志记录调用,例如记录训练迭代数据的函数,请考虑对其应用 torch.compiler.disable

或者,您可以尝试使用 torch._dynamo.config.reorderable_logging_functions。此配置用于重新排序日志函数,以便在跟踪函数的末尾调用它们,从而避免图中断。但是,如果发生突变等情况,日志内容可能会有所不同。

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!

错误的代码#

您的代码可能是错误的,或者遇到了 torch.compile 之外的错误。在下面的代码中,我们在 torch.sin 调用中通过提供额外的参数而犯了一个拼写错误。

import torch

@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
    y = torch.sin(x, x)
...

从日志中很难判断错误是由于您的代码引起的还是由于 torch.compile 的 bug。为了区分,我们建议尝试在不使用 torch.compile 的情况下运行您的代码,看看是否仍然出现错误。

处理重新编译#

您可以使用 tlparseTORCH_LOGS=recompiles 查看重新编译及其原因。

动态形状是否已启用?#

由于形状不匹配而导致的重新编译形式为

tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

确保 torch.compiledynamic 选项未设置为 False。默认选项 dynamic=None 仅在首次编译后尝试动态形状。您可以将 dynamic=True 设置为尽可能多地进行动态编译。

有关动态形状的更多信息,请参阅 动态形状手册

更改缓存大小限制#

函数可以被重新编译的次数是有限制的,由 torch._dynamo.config.recompile_limittorch._dynamo.config.accumulated_recompile_limit 决定。如果其中任何一个限制被超过,我们将不再尝试重新编译该函数,而是改为急切地运行该函数。torch.compile 还会发出一个包含受影响函数和哪个限制被命中的警告。在下面的示例中,每次函数调用都会尝试重新编译。当我们达到缓存大小限制(8)时,我们停止尝试重新编译。

import torch

@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
    last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

如果您知道重新编译的次数有一个合理的常量上限,您可以提高缓存大小限制。如果重新编译的成本超过了编译的好处,那么您可以考虑降低缓存大小限制。

用张量包装常量#

默认情况下,int / float 变量被视为常量并以此方式进行 guard。在下面的示例中,我们每次函数调用都有一个重新编译。

import torch

@torch.compile
def fn(x, c):
    return x + c

for i in range(1, 10):
    fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/7: L['c'] == 8.5
    - 0/6: L['c'] == 7.5
    - 0/5: L['c'] == 6.5
    - 0/4: L['c'] == 5.5
    - 0/3: L['c'] == 4.5
    - 0/2: L['c'] == 3.5
    - 0/1: L['c'] == 2.5
    - 0/0: L['c'] == 1.5
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
    last reason: 0/0: L['c'] == 1.5

特别是,对于 LR 调度器,使用常量初始化会导致重新编译

import torch

mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def fn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(1, 10):
    fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
    triggered by the following guard failure(s):
    - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
    - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
    - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
    - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
    - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
    - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
    - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
    - 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.recompile_limit (8)
    function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
    last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

在这两个示例中,我们都可以用张量包装浮点变量以防止重新编译。

# first example
for i in range(1, 10):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))

报告问题#

如果上述解决方法不足以使 torch.compile 工作,那么您应该考虑向 PyTorch 报告问题。但您可以做一些事情来使我们的工作更轻松。

消融#

使用 torch.compilebackend= 选项检查 torch.compile 堆栈的哪个组件是导致问题的组件。特别是,尝试

  • torch.compile(fn, backend="eager"),它只运行 TorchDynamo,即 torch.compile 的图捕获组件。

  • torch.compile(fn, backend="aot_eager"),它运行 TorchDynamo 和 AOTAutograd,它在编译期间另外生成后向图。

  • torch.compile(fn, backend="aot_eager_decomp_partition"),它运行 TorchDynamo 和 AOTAutograd 以及算子分解/分区。

  • torch.compile(fn, backend="inductor"),它运行 TorchDynamo、AOTAutograd 和 TorchInductor,这是生成编译内核的后端 ML 编译器。

如果您仅在使用 Inductor 后端时失败,您还可以测试各种 Inductor 模式

  • torch.compile(fn, backend="inductor", mode="default")

  • torch.compile(fn, backend="inductor", mode="reduce-overhead")

  • torch.compile(fn, backend="inductor", mode="max-autotune")

您还可以检查动态形状是否在任何后端导致问题

  • torch.compile(fn, dynamic=True)(始终使用动态形状)

  • torch.compile(fn, dynamic=False)(从不使用动态形状)

  • torch.compile(fn, dynamic=None)(自动动态形状)

二分法#

您尝试过最新的 nightly 版本吗?过去有效但现在不再有效了吗?您能否通过二分法确定您的问题的第一个 nightly 版本?二分法对于性能、准确性或编译时间回归特别有用,在这些情况下,问题来源并不明显。

创建复现器#

创建复现器需要大量工作,如果您没有时间来完成它,那是完全可以的。但是,如果您是一个积极的用户,不熟悉 torch.compile 的内部机制,创建独立的复现器将极大地影响我们修复 bug 的能力。没有复现器,您的 bug 报告必须包含足够的信息供我们识别问题的根本原因并从头开始编写复现器。

以下是可用复现器的列表,按首选程度从高到低排序

  1. 独立的、小的复现器:一个没有外部依赖项的脚本,代码行数少于 100 行,运行时可复现问题。

  2. 独立的、大的复现器:即使它很大,独立性也是一个巨大的优势!

  3. 具有可管理依赖项的非独立复现器:例如,如果您在运行脚本后通过 pip install transformers 复现了问题,那是可以管理的。我们很可能可以运行它并进行调查。

  4. 需要大量设置的非独立复现器:这可能涉及下载数据集、多个环境设置步骤或需要 Docker 镜像的特定系统库版本。设置越复杂,我们创建环境就越困难。

    注意

    Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary.
    

从某种程度上说,一个可以在单个进程中运行的复现器比一个需要多进程训练的复现器要好(但同样,如果您只有一个多进程复现器,我们也会接受!)。

此外,以下是您可以在问题中检查的方面的一个非详尽列表,您可以在复现器中尝试复现这些方面

  • Autograd。您的张量输入是否具有 requires_grad=True?您是否在输出上调用了 backward()

  • 动态形状。您是否设置了 dynamic=True?或者您是否多次运行测试代码并改变形状?

  • 自定义算子。真实工作流程中是否涉及自定义算子?您是否可以使用 Python 自定义算子 API 复现其某些重要特征?

  • 配置。您是否设置了所有相同的配置?这包括 torch._dynamo.configtorch._inductor.config 设置,以及 torch.compile 的参数,如 backend / mode

  • 上下文管理器。您是否复现了任何活动的上下文管理器?这可能是 torch.no_grad、自动混合精度、TorchFunctionMode / TorchDispatchMode、激活检查点、编译后的 autograd 等。

  • 张量子类。是否涉及张量子类?

Minifier#

minifier 是一个早期的 torch.compile 工具,它接收一个在尝试运行或编译时崩溃的 FX 图,找到一个同样会崩溃的子图,并输出执行该子图操作的代码。本质上,minifier 找到了特定类别的 torch.compile 相关崩溃的最小复现。这假设我们能够成功地跟踪代码。

不幸的是,如今大多数时候,minifier 无法按预期工作,可能需要其他方法。这可能是因为这类可以通过自动复现的 bug 通常更容易修复并且已经得到解决,从而留下更多难以轻松复现的复杂问题。但是,尝试使用 minifier 是直接的,因此即使可能不成功,也值得一试。

minifier 的操作说明可以在 这里 找到。如果编译器崩溃,您可以设置 TORCHDYNAMO_REPRO_AFTER="dynamo"TORCHDYNAMO_REPRO_AFTER="aot"aot 选项更有可能成功,尽管它可能无法识别 AOTAutograd 问题。这将生成 repro.py 文件,这可能有助于诊断问题。对于与准确性相关的问题,请考虑设置 TORCHDYNAMO_REPRO_LEVEL=4。请注意,这可能不总是能成功识别有问题的子图。

深入调试#

本节提供了独立调试 torch.compile 问题或更深入理解 torch.compile 堆栈的工具和技术。这些方法比上面介绍的方法更复杂,并且由 PyTorch 开发人员定期用于调试实际的 torch.compile 问题。

以下是堆栈的高层概述

Torch Dynamo Stack

堆栈主要包括三个组件:TorchDynamo、AOTAutograd 和 Inductor。我们的调试策略是首先识别错误发生的组件,然后单独调试该组件。要确定负责该问题的组件,请参阅上面“报告问题”下的“消融”部分。有关调试特定组件的指南,请参阅以下各节。

TorchDynamo#

记录 Dynamo 正在跟踪的内容#

TORCH_LOGS=trace_bytecode 选项使您能够查看 Dynamo 正在跟踪的精确字节码指令,以及 Python 解释器堆栈的符号表示。在遇到图中断或崩溃时,建议检查跟踪的最后几个字节码指令。

您还可以使用 TORCH_LOGS=trace_source 来查看 Dynamo 正在跟踪的源代码行。这与 trace_bytecode 结合使用非常有用,可以查看每个跟踪的字节码指令对应的源代码行。

最后,您可以使用 TORCH_LOGS=graph_code 来查看 Dynamo 跟踪的 FX 图的 Python 代码表示。您可以查看此代码以仔细检查正在跟踪的正确操作。

import torch

def g(x, y):
    return x + y

@torch.compile(backend="eager")
def f(x):
    x = torch.sin(x)
    x = g(x, x)
    return x

f(torch.ones(3, 3))
$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
    @torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
        x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
        x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
    def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
        return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
        return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

        # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
        x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

        # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
        x_1: "f32[3, 3][3, 1]cpu" = x + x;  x = None
        return (x_1,)

在 Dynamo 跟踪处设置断点#

在 Dynamo/用户代码中插入断点有时有助于查看 Dynamo 在跟踪用户代码时的状态。不幸的是,以常规 Python 方式插入断点会导致 TorchDynamo 中的图中断,因此我们无法在预期的断点处查看 Dynamo 的状态。

设置断点的第一种方法是将其插入 Dynamo 源代码中。建议放置断点的三个位置是

  • torch/_dynamo/symbolic_convert.py 中,在函数名与有问题的字节码指令相同的函数处设置断点,例如 def CALL_FUNCTIONdef STORE_ATTR。您可以根据输入有条件地设置断点,例如指令的 argval,或者堆栈顶部的对象名称,因为某些字节码操作码使用频繁。

  • 在图中断或错误起源处设置断点。通常,图中断是从对 unimplemented(...) 的调用发出的。

  • torch/_dynamo/variables/builder.py, function:_wrap 中设置断点。您很可能需要根据输入有条件地设置断点。此函数确定如何符号化地表示给定值。如果您怀疑某个值表示不正确,请考虑在此处设置断点。

设置断点的第二种方法是使用 torch._dynamo.comptime.comptime.breakpoint

from torch._dynamo.comptime import comptime

@torch.compile
def f(...):
    ...
    comptime.breakpoint()
    ...

comptime 断点很方便,因为它允许您在正在跟踪的用户代码中的特定位置检查 Dynamo 状态。它不需要您在 Dynamo 源代码中设置断点或根据变量有条件地设置断点。

当触发 comptime 断点时,您可以执行以下操作

  • ctx.print_bt() 打印用户堆栈跟踪

  • ctx.print_locals() 打印所有当前的局部变量

  • ctx.print_graph() 打印当前跟踪的图

  • ctx.disas() 打印当前跟踪函数的字节码

  • 使用标准的 pdb 命令,例如 bt/u/d/n/s/r——您可以向上移动 pdb 堆栈以检查更多 Dynamo 内部信息

import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x + 1
    comptime.breakpoint()
    y = y + 1
    return y

f(torch.ones(3, 3))
$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
    comptime.breakpoint()

(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()



def forward(self, L_x_: "f32[3, 3]"):
    l_x_ = L_x_

    # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
    y: "f32[3, 3]" = l_x_ + 1;  l_x_ = y = None

字节码生成错误#

虽然不常见,但 Dynamo 可能会生成错误的字节码。如果您确定以下情况,这可能会发生

  • 消融分析显示错误发生在 TorchDynamo 级别

  • 错误不是从 TorchDynamo 堆栈帧发出的

  • 错误看起来更像是用户错误而不是 Dynamo 错误,或者是一个分段错误

  • 不使用 torch.compile 时不会发生错误

字节码生成 bug 通常很难修复,我们建议提交问题而不是尝试自行修复。如果您有兴趣查看 Dynamo 生成的字节码,可以使用 TORCH_LOGS=bytecode。您可以在 此处 获得 Dynamo 生成字节码的高层概述。

AOTAutograd#

AOTAutograd 错误通常很难调试——我们建议直接提交问题。AOTAutograd 的日志输出主要有助于查看 Inductor 的输入。

TORCH_LOGS 选项摘要#

有用的 TORCH_LOGS 选项摘要如下

选项

描述

+all

输出所有 torch.compile 组件的调试日志

+dynamo

输出 TorchDynamo 的调试日志

+aot

输出 AOTAutograd 的调试日志

+inductor

输出 TorchInductor 的调试日志

dynamic

输出动态形状的日志

graph_code

输出 Dynamo 生成的 FX 图的 Python 代码

graph_sizes

输出 Dynamo 生成的 FX 图的张量大小

trace_bytecode

输出 Dynamo 正在跟踪的字节码指令以及 Dynamo 正在跟踪的符号解释器堆栈

trace_source

输出 Dynamo 当前正在追踪的源代码行

bytecode

输出 Dynamo 生成的字节码

guards

输出生成的 guard

recompiles

输出重新编译的原因(仅第一个失败的 guard 检查)

recompiles_verbose

输出重新编译发生时所有失败的 guard 检查

aot_graphs

输出 AOTAutograd 生成的图

aot_joint_graphs

输出 AOTAutograd 生成的前向-后向联合图

output_code

输出 Inductor 生成的代码

kernel_code

输出 Inductor 按内核生成的代码

schedule

输出 Inductor 调度日志

perf_hints

输出 Inductor perf hint 日志

fusion

输出 Inductor fusion 日志

有关选项的完整列表,请参阅 torch._loggingtorch._logging.set_logs