评价此页

常见问题#

创建于:2025 年 6 月 16 日 | 最后更新于:2025 年 6 月 16 日

作者Mark Saroufim

torch.compile 支持训练吗?#

torch.compile 支持训练,它使用 AOTAutograd 来捕获反向传播。

  1. .forward() 图和 optimizer.step() 由 TorchDynamo 的 Python evalframe 前端捕获。

  2. 对于 torchdynamo 捕获的每个 .forward() 片段,它使用 AOTAutograd 生成一个反向传播图片段。

  3. 每个前向和后向图对(可选地)通过最小割进行分区,以在训练前向和后向之间保存最小状态。

  4. 前向和后向对被包装在 autograd.function 模块中。

  5. 用户调用 .backward() 的代码仍然会触发 eager 的 autograd 引擎,该引擎将每个*已编译的反向传播*图作为一个操作来运行,同时也会运行任何未编译的 eager 操作的 .backward() 函数。

您是否支持分布式代码?#

torch.compile 支持 DistributedDataParallel (DDP)。其他分布式训练库的支持正在考虑中。

分布式代码与 dynamo 交互具有挑战性的主要原因在于,AOTAutograd 会展开前向和后向传播,并为后端优化提供两个图。这对于分布式代码来说是个问题,因为我们希望理想情况下能够将通信操作与计算重叠。Eager PyTorch 通过多种方式实现这一点,例如 DDP/FSDP,通过 autograd hook、module hook 以及模块状态的修改/突变。在 naive 应用 dynamo 的情况下,由于 AOTAutograd 编译函数与 dispatch hook 的交互方式,本应在反向传播过程中紧随操作运行的 hook 可能会被延迟到整个已编译的反向传播操作区域之后。

优化 DDP 与 Dynamo 的基本策略在 distributed.py 中进行了概述,其主要思想是在 DDP bucket 边界 上进行图拆分。

当 DDP 中的每个节点需要与其节点同步权重时,它会将其梯度和参数组织成 buckets,这可以减少通信时间,并允许节点广播其一部分梯度给其他等待的节点。

分布式代码中的图拆分意味着您可以期望 dynamo 及其后端能够优化分布式程序的计算开销,但无法优化其通信开销。图拆分可能会干扰编译加速,因为减少的图大小会剥夺编译器融合的机会。然而,随着图大小的增加,收益会递减,因为当前大多数计算优化都是局部融合。所以实际上这种方法可能就足够了。

我还需要导出整个图吗?#

对于绝大多数模型,您可能不需要,并且可以按原样使用 torch.compile。但有些情况需要完整的图,您可以通过简单地运行 torch.compile(..., fullgraph=True) 来确保完整的图。这些情况包括:

  • 大规模训练运行,例如 250K+,需要流水线并行和其他高级分片策略。

  • 推理优化器,如 TensorRTAITemplate,它们依赖于比训练优化器更激进的融合。

  • 移动端训练或推理。

未来的工作将包括将通信操作跟踪到图中,协调这些操作与计算优化,以及优化通信操作。

为什么我的代码崩溃了?#

如果您的代码在未启用 torch.compile 时运行正常,但在启用它后开始崩溃,那么最重要的第一步是确定故障发生在堆栈的哪个部分。为了解决这个问题,请按照以下步骤操作,并且仅在前一个步骤成功后尝试下一步。

  1. torch.compile(..., backend="eager") 仅运行 TorchDynamo 前向图捕获,然后使用 PyTorch 运行捕获的图。如果这失败了,那么 TorchDynamo 就存在问题。

  2. torch.compile(..., backend="aot_eager") 运行 TorchDynamo 捕获前向图,然后使用 AOTAutograd 跟踪反向传播图,而无需任何额外的后端编译器步骤。然后,PyTorch eager 将用于运行前向和后向图。如果这失败了,那么 AOTAutograd 就存在问题。

  3. torch.compile(..., backend="inductor") 运行 TorchDynamo 捕获前向图,然后使用 AOTAutograd 跟踪反向传播图,并使用 TorchInductor 编译器。如果这失败了,那么 TorchInductor 就存在问题。

为什么编译很慢?#

  • Dynamo 编译— TorchDynamo 内置了一个统计函数,用于收集和显示每个编译阶段花费的时间。这些统计信息可以通过在执行 torch._dynamo 后调用 torch._dynamo.utils.compile_times() 来访问。默认情况下,它会返回一个按名称列出 TorchDynamo 各函数所花费编译时间的字符串表示。

  • Inductor 编译— TorchInductor 具有内置的统计和跟踪函数,用于显示每个编译阶段所花费的时间、输出代码、输出图可视化和 IR dump。env TORCH_COMPILE_DEBUG=1 python repro.py。这是一个调试工具,旨在帮助您更容易地调试/理解 TorchInductor 的内部机制,其输出将类似于 。该调试跟踪中的每个文件都可以通过 torch._inductor.config.trace.* 启用/禁用。由于剖析图和图表的生成成本很高,因此它们默认是禁用的。请参阅 示例调试目录输出 以获取更多示例。

  • 过度重新编译 当 TorchDynamo 编译一个函数(或其一部分)时,它会针对局部变量和全局变量做出某些假设,以允许编译器进行优化,并将这些假设表示为在运行时检查特定值的 guards。如果任何 guard 失败,Dynamo 将最多重新编译该函数(或其一部分)torch._dynamo.config.recompile_limit 次。如果您的程序达到了缓存限制,您首先需要确定是哪个 guard 失败了,以及是您程序的哪一部分触发了它。使用 TORCH_TRACE/tlparseTORCH_LOGS=recompiles 来跟踪问题的根源,有关更多详细信息,请参阅 torch.compile 故障排除

为什么我在生产环境中进行重新编译?#

在某些情况下,您可能不希望在程序预热后出现意外的编译。例如,如果您在一个对延迟敏感的应用程序中提供生产流量。为此,TorchDynamo 提供了一种替代模式,其中使用先前编译的图,但不会生成新的图。

frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))

您是如何加快我的代码的?#

加速 PyTorch 代码主要有三种方式:

  1. 通过垂直融合实现内核融合,它融合了顺序操作以避免过多的读/写。例如,融合两个连续的余弦函数意味着您可以执行 1 次读、1 次写,而不是 2 次读、2 次写。水平融合:最简单的例子是批处理,其中单个矩阵与一系列示例相乘,但更通用的场景是分组 GEMM,其中一组矩阵乘法被调度在一起。

  2. 无序执行:编译器的一种通用优化,通过提前查看图中的确切数据依赖关系,我们可以决定执行节点的最佳时机以及哪些缓冲区可以重用。

  3. 自动工作负载放置:类似于无序执行的观点,但通过将图的节点与物理硬件或内存等资源进行匹配,我们可以设计一个合适的调度。

以上是加速 PyTorch 代码的通用原则,但不同的后端在优化方面会做出不同的权衡。例如,Inductor 首先负责融合所有可以融合的内容,然后才生成 Triton 内核。

Triton 还通过自动内存合并、内存管理和每个流多处理器内的调度来提供加速,并且已经设计用于处理分块计算。

然而,无论您使用哪个后端,最好采用基准测试和观察的方法,所以请尝试使用 PyTorch Profiler,直观地检查生成的内核,并自己尝试了解发生了什么。

为什么我没有看到加速?#

图中断#

使用 dynamo 时看不到所需加速的主要原因是过多的图拆分。那么什么是图拆分?

给定一个程序,例如

def some_fun(x):
    ...

torch.compile(some_fun)(x)
...

Torchdynamo 将尝试将 some_fun() 中的所有 torch/tensor 操作编译成一个单一的 FX 图,但它可能无法将所有内容捕获到一个图中。

对于 TorchDynamo 来说,一些图拆分的原因是无法克服的,例如调用 PyTorch 以外的 C 扩展对 TorchDynamo 是不可见的,并且可以执行任意操作,而 TorchDynamo 无法引入必要的 guards 来确保编译后的程序可以安全重用。

为了最大化性能,尽可能少地出现图拆分很重要。

识别图拆分的原因#

要识别程序中的所有图拆分及其原因,可以使用 torch._dynamo.explain。此工具会在提供的函数上运行 TorchDynamo,并聚合遇到的图拆分。这是一个用法示例:

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

要通过设置 fullgraph=True 来禁用 Python 回退,从而在遇到的第一个图拆分时抛出错误。如果您使用过基于 export 的编译器,这应该很熟悉。

def toy_example(a, b):
   ...

torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

为什么我更改代码后没有重新编译?#

如果您通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py 启用了动态形状,那么您的代码将不会在形状变化时重新编译。我们添加了对动态形状的支持,这可以在形状变化小于 2 倍的情况下避免重新编译。这在 CV 中的不同图像尺寸或 NLP 中的可变序列长度等场景中特别有用。在推理场景中,由于您会接收来自不同客户端应用程序的各种请求,因此通常无法提前知道批量大小是多少。

总的来说,TorchDynamo 会尽力避免不必要的重新编译。例如,如果 TorchDynamo 找到 3 个图,而您的更改只修改了一个图,那么只会重新编译该图。因此,避免可能缓慢的编译时间的另一个技巧是预热模型,在编译一次之后,后续的编译会快得多。冷启动编译时间仍然是我们可见跟踪的指标。

为什么我的结果不正确?#

通过设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4 也可以最小化精度问题,它的工作方式类似于 git bisect 模型,一个完整的重现可能类似于 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4。我们需要这个是因为下游编译器会生成代码,无论是 Triton 代码还是 C++ 后端,这些下游编译器的数值可能存在细微差别,但对您的训练稳定性产生巨大影响。因此,精度调试器对于我们检测代码生成或后端编译器中的 bug 非常有用。

如果您想确保随机数生成在 torch 和 triton 之间是相同的,您可以启用 torch._inductor.config.fallback_random = True

为什么我会遇到 OOM?#

Dynamo 仍处于 alpha 阶段,因此存在一些 OOM 的来源。如果您遇到 OOM,请尝试按顺序禁用以下配置,然后打开一个 GitHub issue,以便我们解决根本问题:1. 如果您正在使用动态形状,请尝试禁用它们,我们默认禁用了它们:env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py 2. Triton 中的 CUDA 图在 inductor 中是默认启用的,但禁用它们可能会缓解一些 OOM 问题:torch._inductor.config.triton.cudagraphs = False

torch.func 是否与 torch.compile 一起使用(用于 gradvmap 转换)?#

torch.func 转换应用于使用 torch.compile 的函数是可行的。

import torch

@torch.compile
def f(x):
    return torch.sin(x)

def g(x):
    return torch.grad(f)(x)

x = torch.randn(2, 3)
g(x)

在由 torch.compile 处理的函数内部调用 torch.func 转换#

使用 torch.compile 编译 torch.func.grad#

import torch

def wrapper_fn(x):
    return torch.func.grad(lambda x: x.sin().sum())(x)

x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)

使用 torch.compile 编译 torch.vmap#

import torch

def my_fn(x):
    return torch.vmap(lambda x: x.sum(1))(x)

x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)

编译不受支持的函数(逃生舱)#

对于其他转换,作为一种解决方法,请使用 torch._dynamo.allow_in_graph

allow_in_graph 是一个逃生舱。如果您的代码无法使用检查 Python 字节码的 torch.compile 工作,但您认为它可以通过符号跟踪方法(如 jax.jit)工作,则使用 allow_in_graph

通过使用 allow_in_graph 注释一个函数,您必须确保您的代码满足以下要求:

  • 您的函数中的所有输出仅取决于输入,而不取决于任何捕获的 Tensor。

  • 您的函数是函数式的。也就是说,它不会修改任何状态。这可能会放宽;实际上,我们支持从外部看起来是函数式的函数:它们可能有就地 PyTorch 操作,但不能修改全局状态或函数的输入。

  • 您的函数不会引发数据相关的错误。

import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)

一个常见的陷阱是使用 allow_in_graph 来注释调用 nn.Module 的函数。这是因为输出现在取决于 nn.Module 的参数。要使此工作,请使用 torch.func.functional_call 来提取模块状态。

NumPy 是否与 torch.compile 一起使用?#

从 2.1 版本开始,torch.compile 可以理解原生 NumPy 程序(作用于 NumPy 数组)以及混合 PyTorch-NumPy 程序(通过 x.numpy()torch.from_numpy 和相关函数在 PyTorch 和 NumPy 之间进行转换)。

torch.compile 支持哪些 NumPy 功能?#

torch.compile 中的 NumPy 遵循 NumPy 2.0 预发布版本。

通常,torch.compile 能够跟踪大多数 NumPy 结构,当它无法跟踪时,它会回退到 eager 模式,并让 NumPy 执行该代码片段。即便如此,在少数功能上,torch.compile 的语义与 NumPy 的略有不同:

  • NumPy 标量:我们将其建模为 0 维数组。也就是说,在 torch.compile 下,np.float32(3) 返回一个 0 维数组。为了避免图拆分,最好使用这个 0 维数组。如果这破坏了您的代码,您可以通过将 NumPy 标量转换为相关的 Python 标量类型(bool/int/float)来解决此问题。

  • 负步长:np.flip 和带有负步长的切片会返回副本。

  • 类型提升:NumPy 的类型提升将在 NumPy 2.0 中发生变化。新规则在 NEP 50 中进行了描述。torch.compile 实现的是 NEP 50,而不是当前即将弃用的规则。

  • {tril,triu}_indices_from/{tril,triu}_indices 返回数组而不是数组元组。

还有其他一些功能我们不支持跟踪,并且我们优雅地回退到 NumPy 来执行它们:

  • 非数字 dtype,如日期时间、字符串、字符、void、结构化 dtype 和 recarrays。

  • 长 dtype np.float128/np.complex256 和一些无符号 dtype np.uint16/np.uint32/np.uint64

  • ndarray 子类。

  • 掩码数组。

  • 晦涩的 ufunc 机制,如 axes=[(n,k),(k,m)->(n,m)] 和 ufunc 方法(例如 np.add.reduce)。

  • complex64/complex128 数组进行排序/排序。

  • NumPy np.poly1dnp.polynomial

  • 具有 2 个或更多返回值的函数中的位置参数 out1, out2out=tuple 是有效的)。

  • __array_function____array_interface____array_wrap__

  • ndarray.ctypes 属性。

我可以使用 torch.compile 编译 NumPy 代码吗?#

当然可以!torch.compile 可以原生理解 NumPy 代码,并将其视为 PyTorch 代码。要做到这一点,只需用 torch.compile 装饰器包装 NumPy 代码即可。

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

使用环境变量 TORCH_LOGS=output_code 执行此示例,我们可以看到 torch.compile 能够将乘法和求和融合到一个 C++ 内核中。它还可以使用 OpenMP 并行执行它们(原生 NumPy 是单线程的)。这可以轻松地使您的 NumPy 代码比原来快 n 倍,其中 n 是您处理器上的核心数!

以这种方式跟踪 NumPy 代码也支持编译代码中的图拆分。

我可以在 CUDA 上执行 NumPy 代码并通过 torch.compile 计算梯度吗?#

是的,您可以!要做到这一点,您只需在 torch.device("cuda") 上下文中执行代码。考虑以下示例:

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

在此示例中,numpy_fn 将在 CUDA 上执行。为了实现这一点,torch.compile 会自动将 XY 从 CPU 移动到 CUDA,然后将结果 Z 从 CUDA 移动到 CPU。如果我们多次在同一程序运行中执行此函数,我们可能希望避免所有这些相当昂贵的内存副本。要做到这一点,我们只需调整我们的 numpy_fn,使其能够接受 CUDA Tensor 并返回 Tensor。我们可以通过使用 torch.compiler.wrap_numpy 来实现这一点。

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在这里,我们显式地在 CUDA 内存中创建 Tensor,并将它们传递给函数,该函数在 CUDA 设备上执行所有计算。wrap_numpy 负责将任何 torch.Tensor 输入标记为具有 np.ndarray 语义的输入,在 torch.compile 级别。在编译器中标记 Tensor 是一个非常便宜的操作,因此在运行时不会发生数据复制或数据移动。

使用此装饰器,我们还可以区分 NumPy 代码!

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)

我们一直在使用 fullgraph=True,因为在这种情况下,图拆分是有问题的。当发生图拆分时,我们需要具体化 NumPy 数组。由于 NumPy 数组没有 devicerequires_grad 的概念,因此在图拆分期间会丢失此信息。

我们无法通过图拆分传播梯度,因为图拆分代码可能会执行任意代码,而这些代码不知道如何进行区分。另一方面,在 CUDA 执行的情况下,我们可以通过使用 torch.device("cuda") 上下管理器来解决这个问题,就像我们在第一个示例中所做的那样。

@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在图拆分期间,中间 Tensor 仍然需要移动到 CPU,但在图拆分后恢复跟踪时,图的其余部分仍然在 CUDA 上进行跟踪。考虑到这种 CUDA <> CPU 和 CPU <> CUDA 的移动,图拆分在 NumPy 上下文中相当昂贵,应该避免,但至少它们允许跟踪复杂的代码片段。

如何在 torch.compile 下调试 NumPy 代码?#

调试 JIT 编译代码具有挑战性,考虑到现代编译器的复杂性以及它们引发的令人畏惧的错误。 torch.compile 故障排除文档 包含一些解决此任务的技巧。

如果以上不足以确定问题的根源,我们仍然可以使用一些其他的 NumPy 特定工具。我们可以通过禁用对 NumPy 函数的跟踪来区分 PyTorch 代码中是否存在 bug:

from torch._dynamo import config
config.trace_numpy = False

如果 bug 出在跟踪的 NumPy 代码中,我们可以通过导入 import torch._numpy as np 来使用 PyTorch 作为后端,以 eager 模式(无需 torch.compile)执行 NumPy 代码。这仅应用于调试目的,绝不能替代 PyTorch API,因为它性能差很多,并且作为私有 API,可能会在未通知的情况下更改。无论如何,torch._numpy 是 PyTorch 的 NumPy Python 实现,它被 torch.compile 内部用于将 NumPy 代码转换为 PyTorch 代码。它相对容易阅读和修改,因此如果您发现任何 bug,请随时提交 PR 进行修复或直接打开 issue。

如果导入 torch._numpy as np 后程序可以正常工作,那么 bug 很可能在 TorchDynamo 中。如果是这种情况,请随时打开一个 issue,并附带一个 最小可重现示例

我使用 torch.compile 编译了一些 NumPy 代码,但没有看到任何加速。#

最好的起点是 关于如何调试此类 torch.compile 问题的通用建议教程

一些图拆分可能是由于使用了不受支持的功能。请参阅 torch.compile 支持哪些 NumPy 功能?。更一般地说,牢记一些广泛使用的 NumPy 功能与编译器配合不佳。例如,就地修改使编译器中的推理变得困难,并且通常会产生比其异地对应版本差的性能。因此,最好避免它们。同样适用于使用 out= 参数。相反,请优先使用异地操作,让 torch.compile 优化内存使用。同样的情况也适用于数据依赖操作,如通过布尔掩码进行的掩码索引,或数据依赖控制流,如 ifwhile 结构。

用于细粒度跟踪的 API 是什么?#

在某些情况下,您可能需要将代码的小部分排除在 torch.compile 编译之外。本节提供了一些答案,您可以在 用于细粒度跟踪的 TorchDynamo API 中找到更多信息。

如何对函数进行图拆分?#

对函数进行图拆分不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的用例。一些最常见的用例可能需要您考虑:

  • 如果您想禁用此函数帧及其递归调用的帧的编译,请使用 torch._dynamo.disable

  • 如果您希望某个特定的操作符(例如 fbgemm)使用 eager 模式,请使用 torch._dynamo.disallow_in_graph

一些不常见的用例包括:

  • 如果您想禁用函数帧上的 TorchDynamo,但在递归调用的帧上重新启用它——请使用 torch._dynamo.disable(recursive=False)

  • 如果您想阻止内联函数帧——请在您想要阻止内联的函数开头使用 torch._dynamo.graph_break

torch._dynamo.disabletorch._dynamo.disallow_in_graph 之间有什么区别?#

Disallow-in-graph 在操作符级别工作,或者更具体地说,在 TorchDynamo 提取的图中看到的操作符级别工作。

Disable 在函数帧级别工作,并决定 TorchDynamo 是否应该查看函数帧。

torch._dynamo.disabletorch._dynamo_skip 之间有什么区别?#

注意

torch._dynamo_skip 已弃用。

您最有可能需要 torch._dynamo.disable。但在不太可能的情况下,您可能需要更精细地控制。假设您只想禁用 a_fn 函数上的跟踪,但希望在 aa_fnab_fn 中继续跟踪。下图演示了此用例:

diagram of torch.compile + disable(a_fn, recursive=False)

在这种情况下,您可以使用 torch._dynamo.disable(recursive=False)。在以前的版本中,此功能由 torch._dynamo.skip 提供。现在,这由 torch._dynamo.disable 中的 recursive 标志支持。