CUDAGraph 树#
创建时间:2023 年 5 月 19 日 | 最后更新时间:2025 年 7 月 30 日
背景#
CUDAGraph#
有关 CUDAGraph 的更详细背景信息,请阅读 使用 CUDAGraph 加速 PyTorch。
CUDA Graphs 于 CUDA 10 首次亮相,它允许将一系列 CUDA 内核定义并封装为一个单元,即操作图,而不是一系列单独启动的操作。它提供了一种机制,通过一次 CPU 操作即可启动多个 GPU 操作,从而降低了启动开销。
CUDA Graphs 可以带来显著的加速,尤其是对于 CPU 开销高或计算量小的模型。由于要求相同的内核以相同的参数和依赖项以及内存地址运行,因此存在一些限制。
无法进行控制流
触发主机到设备同步的内核(例如 `.item()`)会报错
内核的所有输入参数都固定为其记录时的值
CUDA 内存地址是固定的,但这些地址处的内存值可以更改
不包含核心 CPU 操作或 CPU 端副作用
PyTorch CUDAGraph 集成#
PyTorch 提供了 CUDAGraph 的一个 方便的包装器,该包装器处理了与 PyTorch 缓存分配器的一些棘手交互。
缓存分配器使用独立的内存池来处理所有新分配。在 CUDAGraph 记录期间,内存的核算、分配和释放与在即时运行期间完全相同。在回放时,只调用内核,分配器没有任何变化。在初始记录之后,分配器不知道哪些内存正在用户程序中被主动使用。
在即时分配和 cudagraph 分配之间使用独立的内存池,如果两者都有大量内存分配,可能会增加程序的内存占用。
创建图表可调用对象#
Make Graphed Callables 是一个 PyTorch 抽象,用于在一系列可调用对象之间共享单个内存池。Graphed Callables 利用了 CUDA Graph 记录期间缓存分配器精确核算内存的事实,以安全地在独立的 CUDA Graph 记录之间共享内存。在每次调用中,输出都会被保留为活动内存,防止一个可调用对象覆盖另一个的可活动内存。Graphed Callables 只能按单一顺序调用;第一个运行的内存地址将被烧录到第二个,依此类推。
TorchDynamo 之前的 CUDA Graphs 集成#
使用 `cudagraph_trees=False` 运行不会在独立的图捕获之间重用内存,这可能导致内存回归。即使是对于没有图中断的模型,这也会有问题。前向和后向是独立的图捕获,因此前向和后向的内存池不共享。特别是,在前向中保存的激活的内存无法在后向中回收。
CUDAGraph 树集成#
与 Graphed Callables 类似,CUDA Graph Trees 在所有图捕获之间使用单个内存池。但是,CUDA Graph Trees 不要求单一的调用序列,而是创建独立的 CUDA Graph 捕获树。让我们看一个说明性示例
@torch.compile(mode="reduce-overhead")
def foo(x):
# GRAPH 1
y = x * x * x
# graph break triggered here
if y.sum() > 0:
# GRAPH 2
z = y ** y
else:
# GRAPH 3
z = (y.abs() ** y.abs())
torch._dynamo.graph_break()
# GRAPH 4
return z * torch.rand_like(z)
# the first run warms up each graph, which does things like CuBlas or Triton benchmarking
foo(torch.arange(0, 10, device="cuda"))
# The second run does a CUDA Graph recording, and replays it
foo(torch.arange(0, 10, device="cuda"))
# Finally we hit the optimized, CUDA Graph replay path
foo(torch.arange(0, 10, device="cuda"))
在此示例中,我们通过函数有两条独立的路径:1 -> 2 -> 4,或 1 -> 3 -> 4。
我们通过构建一个 CUDA Graph 记录的 tape(在此实例中是 1 -> 2 -> 4)来在独立的记录之间共享单个内存池中的所有内存。我们添加了不变量来确保内存始终位于记录时的相同位置,并且用户程序中没有可能被覆盖的活动张量。
CUDA Graphs 的相同约束适用:必须使用相同的参数(静态大小、地址等)调用相同的内核
在记录和回放之间必须观察到相同的内存模式:如果在记录期间一个图的张量输出在另一个图之后死亡,那么在回放期间也必须如此。
CUDA 池中的活动内存会强制两个记录之间产生依赖关系
这些记录只能按单一顺序调用 1 -> 2 -> 4
所有内存都共享在一个内存池中,因此与即时运行相比没有额外的内存开销。那么,如果我们遇到一条新路径并运行图 3 会发生什么?
图 1 被回放,然后我们遇到图 3,这之前我们还没有记录。在图回放时,私有内存池不会更新,因此 y 不会反映在分配器中。如果不小心,我们会覆盖它。为了支持在回放其他图后重用相同的内存池,我们将内存池回滚到图 1 结束时的状态。现在我们的活动张量已反映在缓存分配器中,我们可以安全地运行新图了。
首先,我们将命中优化过的、已在图 1 中记录的 CUDAGraph.replay() 路径。然后我们将命中图 3。和以前一样,我们需要在记录之前预热一次图。在预热运行中,内存地址不是固定的,因此图 4 也将回退到 inductor 的非 cudagraph 调用。
第二次遇到图 3 时,我们已经预热并准备好记录。我们记录图 3,然后再次记录图 4,因为输入内存地址已更改。这创建了一个 CUDA Graph 记录的树。一个 CUDA Graph Tree!
1
/ \\
2 3
\\ \\
4 4
输入变异支持#
输入变异函数是指在输入张量上进行原地写入的函数,如下例所示
def foo(x, y):
# mutates input x
x.add_(1)
return x + y
输入变异函数通常给 CUDAGraph Trees 带来挑战。由于 CUDAGraph 需要静态 CUDA 内存地址,对于每个输入张量 x,CUDAGraph Trees 可能会分配一个静态内存地址 x'。在执行期间,CUDAGraph Trees 首先将输入张量 x 复制到静态内存地址 x',然后回放记录的 CUDAGraph。对于输入变异函数,x' 是原地更新的,这不会反映在输入张量 x 上,因为 x 和 x' 位于不同的 CUDA 内存地址。
仔细查看输入变异函数会发现有三种类型的输入
来自即时运行的输入:我们假设这些张量在每次执行时输入张量地址都会发生变化。由于 cudagraph 会冻结内存地址,因此我们需要在图记录和执行之前将这些输入复制到一个静态地址张量。
参数和缓冲区:我们假设(并在运行时检查)这些张量在每次执行时都有相同的张量地址。我们不需要复制它们的内容,因为记录的内存地址将与执行的内存地址相同。
来自 CUDAGraph Trees 的先前输出的张量:由于 cudagraph 的输出张量地址是固定的,如果我们运行 CUDAGraph1,然后运行 CUDAGraph2,从 CUDAGraph1 输入到 CUDAGraph2 的输入将具有固定的内存地址。这些输入,如参数和缓冲区,不需要复制到静态地址张量。我们检查以确保这些输入在运行时稳定,如果不稳定,我们将重新记录。
CUDAGraph Trees 支持对参数和缓冲区以及来自 CUDAGraph Trees 的先前输出的张量进行输入变异。对于来自即时运行输入的变异,CUDAGraph Trees 将在没有 CUDAGraph 的情况下运行函数,并发出“因输入变异而跳过”的日志。以下示例显示了 CUDAGraph Trees 对来自 CUDAGraph Trees 的先前输出的张量的支持。
import torch
@torch.compile(mode="reduce-overhead")
def foo(x):
return x + 1
@torch.compile(mode="reduce-overhead")
def mut(x):
return x.add_(2)
# Enable input mutation support
torch._inductor.config.triton.cudagraph_support_input_mutation = True
for i in range(3):
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
# CUDAGraph is applied since `foo` does not mutate `inp`
tmp = foo(inp)
# Although `mut` mutates `tmp`, which is an output of a CUDAGraph
# managed function. So CUDAGraph is still applied.
mut(tmp)
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
tmp = foo(inp)
# While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()`
# is not. So CUDAGraph is not applied to `mut` and there is a log
# `skipping cudagraphs due to mutated inputs`
mut(tmp.clone())
要为修改来自即时运行输入的函数启用 CUDAGraph Trees,请重写该函数以避免输入变异。
注意
通过设置 torch._inductor.config.cudagraph_support_input_mutation = True 来为“减少开销”模式启用输入变异支持。
动态形状支持#
动态形状意味着输入张量在函数调用之间具有不同的形状。由于 CUDAGraph 需要固定的张量地址,CUDAGraph Trees 会为输入张量的每种唯一形状重新记录 CUDAGraph。这会导致一个 inductor graph 有多个 CUDAGraph。当形状有限时(例如,推理中的批处理大小),重新记录 CUDAGraph 是有利的。然而,如果输入张量形状频繁更改,甚至在每次调用时都更改,重新记录 CUDAGraph 可能不划算。Nvidia 在 CUDA 12.4 和驱动程序版本 550+ 之前,每次启动 CUDAGraph 时使用 64 KB 设备内存。对于许多 CUDAGraph 的重新记录,这种内存成本可能非常可观。
对于输入张量形状频繁更改的函数,我们建议将输入张量填充到少数固定张量形状,以仍然享受 CUDAGraph 的好处。此外,设置 torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True 可以跳过具有动态形状输入的函数进行 cudagraphing,只对具有静态输入张量形状的函数进行 cudagraphing。
NCCL 支持#
CUDAGraph Trees 支持包含 nccl 运算符的函数。虽然 CUDAGraph Trees 对 CUDAGraph 进行每个设备记录,但 NCCL 支持允许跨设备通信。
@torch.compile(mode="reduce-overhead")
def func(x):
y = x * x
y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM)
x = torch.nn.functional.silu(x)
return x * y
跳过 CUDAGraph 的原因#
由于 CUDAGraph 有静态输入张量地址的要求,并且不支持 CPU 运算符,CUDAGraph Trees 会检查函数是否满足这些要求,并在必要时跳过 CUDAGraph。在此,我们列出了跳过 CUDAGraph 的常见原因。
输入变异:CUDAGraph Trees 会跳过原地变异即时输入的函数。原地变异参数和缓冲区,或来自 CUDAGraph Tree 管理函数的输出张量仍然受支持。请参阅“输入变异支持”部分了解更多详细信息。
CPU 运算符:包含 CPU 运算符的函数将被跳过。请将函数拆分为多个函数,并将 CUDAGraph Trees 应用于仅包含 GPU 运算符的函数。
多设备运算符:如果函数包含多个设备上的运算符,则会跳过该函数。目前,CUDAGraph 是按设备应用的。请使用 NCCL 等支持的库进行跨设备通信。请参阅“NCCL 支持”部分了解更多详细信息。
释放未支持的符号:释放未支持的符号通常发生在 动态形状期间。CUDAGraph Trees 目前为每种唯一的输入张量形状记录一个 CUDAGraph。请参阅“动态形状支持”部分了解更多详细信息。
CUDAGraph 不安全的自定义运算符:某些自定义运算符可能包含 cudagraph 不安全的运算符,这会导致 cudagraph 被跳过。请参阅“CUDAGraph 不安全的自定义运算符”部分了解更多详细信息。
不兼容的运算符:如果函数包含不兼容的运算符,CUDAGraph Trees 将跳过该函数。请将函数中的这些运算符替换为支持的运算符。我们显示了不兼容运算符的详尽列表
aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar
当 torch.are_deterministic_algorithms_enabled() 时,以下运算符不兼容。
aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar
CUDAGraph 不安全的自定义运算符#
默认情况下,自定义运算符被假定为对 CUDAGraph 是安全的。但是,某些自定义运算符可能包含不支持的运算符,例如 CPU 运算符。由于编译器将自定义运算符视为黑盒,用户必须通过设置 `torch._C.Tag.cudagraph_unsafe` 标签显式将这些运算符标记为对 CUDAGraph 不安全,如下例所示。当函数包含 cudagraph 不安全的自定义运算符时,除非启用了“CUDAGraph 分区”,否则 CUDAGraph 将跳过它。
@torch.library.custom_op(
"mylib::modify",
mutates_args=(),
tags=(torch._C.Tag.cudagraph_unsafe,),
)
def modify(pic: torch.Tensor) -> torch.Tensor:
pic1 = pic + 1
pic1_cpu = (pic1.cpu() + 1) * 2
return pic1_cpu.cuda() + pic
@modify.register_fake
def _(pic):
return torch.empty_like(pic)
CUDAGraph 分区#
如前所述,CUDAGraph 不支持某些运算符(例如 CPU 运算符),这可能会限制其采用。CUDAGraph 分区是一种编译器解决方案,它自动分离这些运算符,重新排序运算符以减少分区数量,并单独将 CUDAGraph 应用于每个分区。请设置 `torch._inductor.config.graph_partition=True` 来启用 CUDAGraph 分区。
考虑以下示例,其中 `x` 和 `y` 是 GPU 输入,但 `y_cpu` 是 CPU 张量。没有图分区,此函数必须因 CPU 运算符而被跳过。通过图分区,CPU 运算符被分离出来,其余的 GPU 运算符被 cudagraphified,从而产生两个独立的 CUDAGraph。
def f(x, y):
x1 = x + 1
y1 = y + 1
y_cpu = y1.cpu() + 1
z = x @ y
return x1 + y1 + z + y_cpu.cuda()
目前,CUDAGraph 分区支持分离以下类型的运算符
非 GPU 运算符:常见的例子包括 CPU 张量上的计算。
设备复制运算符:设备之间的数据传输,例如上面示例中的 `y1.cpu()`。
控制流运算符:控制流运算符已被分离,因为 CUDAGraph 尚不支持它们。
CUDAGraph 不安全的自定义运算符:标记有 `torch._C.Tag.cudagraph_unsafe` 的自定义运算符已被分离。请参阅“CUDAGraph 不安全的自定义运算符”部分了解详细信息。
未支持的 Symints:请参阅“动态形状支持”部分了解更多信息。
限制#
由于 CUDA Graph 固定了内存地址,CUDA Graphs 在处理前一次调用的活动张量方面没有很好的方法。
假设我们正在使用以下代码运行推理进行基准测试
import torch
@torch.compile(mode="reduce-overhead")
def my_model(x):
y = torch.matmul(x, x)
return y
x = torch.randn(10, 10, device="cuda")
y1 = my_model(x)
y2 = my_model(x)
print(y1)
# RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
在单独的 CUDA Graph 实现中,第一次调用的输出将被第二次调用覆盖。在 CUDAGraph Trees 中,我们不想在迭代之间添加意外的依赖关系,导致我们无法命中热路径,也不想过早地释放先前调用的内存。我们的启发式方法是在推理中,我们为 torch.compile 在每次调用时启动一个新迭代,在训练中,只要没有待处理的后向尚未调用,我们也这样做。如果这些启发式方法不正确,您可以使用 `torch.compiler.mark_step_begin()` 标记新迭代的开始,或者在开始下一次运行之前克隆前一次迭代的张量(在 torch.compile 之外)。
比较#
易出错的陷阱 |
单独的 CudaGraph |
CUDAGraph 树 |
---|---|---|
内存可能增加 |
每次图编译时(新大小等) |
如果您也运行非 cudagraph 内存 |
记录 |
每次调用图时 |
将在您程序的任何新、唯一的路径上重新记录 |
易出错的陷阱 |
调用一个图会覆盖前一个调用 |
无法在模型的独立运行之间持久化内存 - 一个训练循环训练,或一次推理运行 |