评价此页

CUDAGraph Trees#

创建于:2023年5月19日 | 最后更新于:2025年6月10日

背景#

CUDAGraph#

关于 CUDAGraph 的更多背景介绍,请阅读 使用 CUDAGraphs 加速 PyTorch

CUDA Graphs 在 CUDA 10 中首次亮相,它允许将一系列 CUDA 内核定义并封装为一个单元,即操作图,而不是一系列单独启动的操作。它提供了一种通过单个 CPU 操作来启动多个 GPU 操作的机制,从而减少了启动开销。

CUDA Graphs 可以带来显著的加速,尤其适用于 CPU 开销大或计算量小的模型。它有一些限制,要求相同的内核以相同的参数和依赖关系、内存地址运行。

  • 不支持控制流

  • 触发主机到设备同步的内核(例如 `.item()`)会报错

  • 内核的所有输入参数都固定为录制时的值

  • CUDA 内存地址是固定的,但这些地址的内存值可以改变

  • 不允许本质上的 CPU 操作或 CPU 侧副作用

PyTorch CUDAGraph 集成#

PyTorch 提供了一个围绕 CUDAGraphs 的便捷封装,它处理了与 PyTorch 缓存分配器的一些棘手交互。

缓存分配器为所有新分配使用一个单独的内存池。在 CUDAGraph 录制期间,内存的核算、分配和释放与即时执行期间完全相同。在重放时,仅调用内核,分配器没有任何变化。在初始录制之后,分配器不知道哪些内存正在被用户程序积极使用。

如果在即时分配和 cudagraph 分配之间使用单独的内存池,可能会增加程序的内存占用,如果两者都有大量内存分配的话。

创建可图形化函数#

Make Graphed Callables 是一个 PyTorch 抽象,用于在多个函数调用之间共享单个内存池。Graphed Callables 利用了缓存分配器在 CUDAGraph 录制时精确核算内存的事实,可以在单独的 CUDA Graph 录制之间安全地共享内存。在每次调用中,输出都被保留为活动的内存,防止一个函数调用覆盖另一个函数的活动内存。Graphed Callables 只能按单一顺序调用;第一个运行的内存地址会被“烧录”到第二个,依此类推。

TorchDynamo 之前的 CUDA Graphs 集成#

运行 cudagraph_trees=False 不会在单独的图捕获之间重用内存,这可能导致严重的内存回归。即使对于没有图中断的模型,也存在这个问题。前向传播和后向传播是单独的图捕获,因此前向和后向的内存池不共享。特别是,前向传播中保存的激活的内存无法在后向传播中回收。

CUDAGraph Trees 集成#

与 Graph Callables 类似,CUDA Graph Trees 在所有图捕获中使用单个内存池。但是,它不要求单一的调用顺序,而是创建单独的 CUDA Graph 捕获树。让我们看一个说明性示例。

@torch.compile(mode="reduce-overhead")
def foo(x):
    # GRAPH 1
    y = x * x * x
    # graph break triggered here
    if y.sum() > 0:
        # GRAPH 2
        z = y ** y
    else:
        # GRAPH 3
        z = (y.abs() ** y.abs())
    torch._dynamo.graph_break()
    # GRAPH 4
    return z * torch.rand_like(z)

# the first run warms up each graph, which does things like CuBlas or Triton benchmarking
foo(torch.arange(0, 10, device="cuda"))
# The second run does a CUDA Graph recording, and replays it
foo(torch.arange(0, 10, device="cuda"))
# Finally we hit the optimized, CUDA Graph replay path
foo(torch.arange(0, 10, device="cuda"))

在此示例中,函数有两条不同的路径:1 -> 2 -> 4,或 1 -> 3 -> 4。

我们通过构建 CUDA Graph 录制带(例如 1 -> 2 -> 4)来共享各个录制之间的所有内存。我们添加不变性以确保内存始终位于录制时的位置,并且用户程序中没有可能被覆盖的活动张量。

  • 仍然适用 CUDA Graphs 的约束:必须以相同的参数(静态大小、地址等)调用相同的内核。

  • 在录制和回放之间必须观察到相同的内存模式:如果在录制期间一个图的张量输出发生在另一个图之后,那么在回放期间也必须如此。

  • CUDA 池中的活动内存会导致两个录制之间产生依赖关系。

  • 这些录制只能按单一顺序调用:1 -> 2 -> 4。

所有内存都在单个内存池中共享,因此与即时执行相比,没有额外的内存开销。那么,如果我们遇到一条新路径并运行图 3 会发生什么?

图 1 被回放,然后我们遇到图 3,我们还没有录制它。在图回放时,私有内存池不会更新,因此 y 没有在分配器中反映出来。如果不小心,我们会覆盖它。为了支持回放其他图后重用相同的内存池,我们将内存池的副本恢复到图 1 结束时的状态。现在我们的活动张量反映在缓存分配器中,我们可以安全地运行一个新图。

首先,我们将命中优化后的 CUDAGraph.replay() 路径,这已经在图 1 中录制过了。然后我们将命中图 3。和以前一样,在录制之前我们需要预热图一次。在预热运行时,内存地址不是固定的,因此图 4 也会回退到 inductor,非 cudagraph 调用。

第二次遇到图 3 时,我们已预热并准备好录制。我们录制图 3,然后再次录制图 4,因为输入内存地址已更改。这就创建了一个 CUDA Graph 录制树。一个 CUDA Graph Tree!

  1
 / \\
2   3
 \\   \\
  4   4

输入突变支持#

输入突变函数是指执行原地写入输入张量的函数,如下所示:

def foo(x, y):
    # mutates input x
    x.add_(1)
    return x + y

输入突变函数通常给 CUDAGraph Trees 带来挑战。由于 CUDAGraph 对 CUDA 内存地址的要求是静态的,对于每个输入张量 x,CUDAGraph Trees 可能会分配一个静态内存地址 x'。在执行期间,CUDAGraph Trees 首先将输入张量 x 复制到静态内存地址 x',然后回放录制的 CUDAGraph。对于输入突变函数,x' 会被原地更新,这不会反映在输入张量 x 上,因为 x 和 x' 位于不同的 CUDA 内存地址。

仔细查看输入突变函数会发现,有三种类型的输入:

  • 来自 eager 的输入:我们假定这些张量在每次执行时输入张量地址都会变化。因为 cudagraphs 会固定内存地址,我们需要在图录制和执行之前将这些输入复制到静态地址张量。

  • 参数和缓冲区:我们假定(并通过运行时检查)这些张量在每次执行时都具有相同的张量地址。我们不需要复制它们的内​​容,因为录制的内存地址将与执行的内存地址相同。

  • 来自 CUDAGraph 树的先前输出的张量:由于 cudagraph 的输出张量地址是固定的,如果我们先运行 CUDAGraph1,然后运行 CUDAGraph2,那么从 CUDAGraph1 进入 CUDAGraph2 的输入将具有固定的内存地址。这些输入,如参数和缓冲区,不需要复制到静态地址张量。我们检查以确保这些输入在运行时是稳定的,如果它们不稳定,我们将重新记录。

CUDAGraph 树支持对参数、缓冲区以及来自 CUDAGraph 树的先前输出的张量进行输入变异。对于来自 eager 的输入的变异,CUDAGraph 树将无 CUDAGraph 地运行函数,并发出因输入变异而跳过的日志。以下示例展示了 CUDAGraph 树对来自 CUDAGraph 树的先前输出的张量的支持。

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1

@torch.compile(mode="reduce-overhead")
def mut(x):
    return x.add_(2)

# Enable input mutation support
torch._inductor.config.triton.cudagraph_support_input_mutation = True

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")

    # CUDAGraph is applied since `foo` does not mutate `inp`
    tmp = foo(inp)
    # Although `mut` mutates `tmp`, which is an output of a CUDAGraph
    # managed function. So CUDAGraph is still applied.
    mut(tmp)


torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")

tmp = foo(inp)
# While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()`
# is not. So CUDAGraph is not applied to `mut` and there is a log
# `skipping cudagraphs due to mutated inputs`
mut(tmp.clone())

要为变异来自 eager 的输入的函数启用 CUDAGraph 树,请重写该函数以避免输入变异。

注意
通过设置 torch._inductor.config.cudagraph_support_input_mutation = True 来为“减少开销”模式启用输入变异支持。

动态形状支持#

动态形状意味着输入张量在函数调用之间具有不同的形状。由于 CUDAGraph 需要固定的张量地址,CUDAGraph 树会为输入张量的每种唯一形状重新记录 CUDAGraph。这会导致单个 inductor 图有多个 CUDAGraph。当形状有限时(例如,推理中的批量大小),重新记录 CUDAGraph 是有利的。但是,如果输入张量形状频繁更改,甚至在每次调用时更改,重新记录 CUDAGraph 可能不划算。Nvidia 在 CUDAGraph 中每启动一个内核使用 64 KB 设备内存,直到 CUDA 12.4 和驱动版本 550+。对于许多 CUDAGraph 重新记录,此内存成本可能很高。

对于输入张量形状频繁更改的函数,我们建议将输入张量填充到几个固定的张量形状,以继续享受 CUDAGraph 的优势。此外,设置 torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True 可以跳过具有动态形状输入的 cudagraphing 函数,只对具有静态输入张量形状的函数进行 cudagraphing。

NCCL 支持#

CUDAGraph 树支持带有 nccl 运算符的函数。虽然 CUDAGraph 树执行 CUDAGraph 的每设备记录,但 NCCL 支持允许跨设备通信。

@torch.compile(mode="reduce-overhead")
def func(x):
    y = x * x
    y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM)
    x = torch.nn.functional.silu(x)
    return x * y

跳过 CUDAGraph 的原因#

由于 CUDAGraph 有诸如静态输入张量地址和不支持 CPU 运算符等要求,CUDAGraph 树会检查函数是否满足这些要求,并在必要时跳过 CUDAGraph。在此,我们列出跳过 CUDAGraph 的常见原因。

  • 输入变异:CUDAGraph 树会跳过就地修改 eager 输入的函数。就地修改参数和缓冲区,或 CUDAGraph 树管理函数的输出张量仍然受支持。请参阅输入变异支持部分了解更多详细信息。

  • CPU 运算符:包含 CPU 运算符的函数将被跳过。请将函数拆分为多个函数,并将 CUDAGraph 树应用于仅包含 GPU 运算符的函数。

  • 多设备运算符:如果函数包含跨多个设备的运算符,则该函数将被跳过。目前,CUDAGraph 是在每个设备的基础上应用的。请使用 NCCL 等支持的库进行跨设备通信。请参阅NCCL 支持部分了解更多详细信息。

  • 释放未备份的符号:释放未备份的符号通常发生在 动态形状期间。CUDAGraph 树目前为每种唯一的输入张量形状记录一个 CUDAGraph。请参阅动态形状支持了解更多详细信息。

  • 不兼容的运算符:如果函数包含不兼容的运算符,CUDAGraph 树将跳过该函数。请在函数中将这些运算符替换为受支持的运算符。我们列出了不兼容运算符的详尽列表

aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar

torch.are_deterministic_algorithms_enabled() 启用时,以下运算符是不兼容的。

aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar

局限性#

由于 CUDA Graph 固定了内存地址,CUDA Graph 在处理先前调用的实时张量方面没有很好的方法。

假设我们使用以下代码来基准测试推理运行

import torch

@torch.compile(mode="reduce-overhead")
def my_model(x):
    y = torch.matmul(x, x)
    return y

x = torch.randn(10, 10, device="cuda")
y1 = my_model(x)
y2 = my_model(x)
print(y1)
# RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.

在单独的 CUDA Graph 实现中,第一次调用的输出将被第二次调用覆盖。在 CUDAGraph 树中,我们不希望在迭代之间添加意外的依赖关系,这会导致我们无法命中热路径,也不希望我们过早释放先前调用的内存。我们的启发式方法是在推理中,我们为 torch.compile 的每次调用启动一个新迭代,在训练中,只要没有待处理的尚未调用的 backward,我们也这样做。如果这些启发式方法不正确,您可以使用 torch.compiler.mark_step_begin() 标记新迭代的开始,或者在开始下一次运行之前克隆先前迭代的张量(在 torch.compile 之外)。

比较#

易出错点

单独的 CudaGraph

CUDAGraph 树

内存可能会增加

每次图编译时(新尺寸等)

如果您还在运行非 cudagraph 内存

记录

在图的任何新调用上

将在您程序的任何新的、唯一的路径上重新记录

易出错点

一个图的调用将覆盖先前的调用

无法在模型的单独运行之间持久化内存 - 一次训练循环训练,或一次推理运行