CUDA 语义#
创建于:2017年1月16日 | 最后更新于:2025年6月18日
torch.cuda
用于设置和运行 CUDA 操作。它会跟踪当前选定的 GPU,并且您分配的所有 CUDA 张量将默认在此设备上创建。选定的设备可以使用 torch.cuda.device
上下文管理器进行更改。
但是,一旦张量被分配,您可以对其进行操作,无论选定的设备如何,结果将始终放置在与张量相同的设备上。
默认情况下不允许跨 GPU 操作,但 copy_()
和其他具有复制功能的 方法(如 to()
和 cuda()
)除外。除非您启用对等内存访问,否则任何尝试在分布在不同设备上的张量上启动操作都将引发错误。
下面是一个小示例,展示了这一点
cuda = torch.device('cuda') # Default CUDA device
cuda0 = torch.device('cuda:0')
cuda2 = torch.device('cuda:2') # GPU 2 (these are 0-indexed)
x = torch.tensor([1., 2.], device=cuda0)
# x.device is device(type='cuda', index=0)
y = torch.tensor([1., 2.]).cuda()
# y.device is device(type='cuda', index=0)
with torch.cuda.device(1):
# allocates a tensor on GPU 1
a = torch.tensor([1., 2.], device=cuda)
# transfers a tensor from CPU to GPU 1
b = torch.tensor([1., 2.]).cuda()
# a.device and b.device are device(type='cuda', index=1)
# You can also use ``Tensor.to`` to transfer a tensor:
b2 = torch.tensor([1., 2.]).to(device=cuda)
# b.device and b2.device are device(type='cuda', index=1)
c = a + b
# c.device is device(type='cuda', index=1)
z = x + y
# z.device is device(type='cuda', index=0)
# even within a context, you can specify the device
# (or give a GPU index to the .cuda call)
d = torch.randn(2, device=cuda2)
e = torch.randn(2).to(cuda2)
f = torch.randn(2).cuda(cuda2)
# d.device, e.device, and f.device are all device(type='cuda', index=2)
Ampere(及更高版本)设备上的 TensorFloat-32 (TF32)#
从 PyTorch 1.7 开始,有一个名为 allow_tf32 的新标志。此标志在 PyTorch 1.7 到 PyTorch 1.11 中默认为 True,在 PyTorch 1.12 及更高版本中默认为 False。此标志控制 PyTorch 是否允许在内部使用 TensorFloat32 (TF32) 张量核心(从 Ampere 开始在 NVIDIA GPU 上可用)来计算 matmul(矩阵乘法和批处理矩阵乘法)和卷积。
TF32 张量核心旨在通过将输入数据四舍五入为 10 位尾数,并以 FP32 精度累积结果,同时保持 FP32 动态范围,从而在 torch.float32 张量上的 matmul 和卷积中实现更好的性能。
matmul 和卷积是分开控制的,它们相应的标志可以在
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
matmul 的精度也可以通过 set_float_32_matmul_precision()
进行更广泛的设置(不限于 CUDA)。请注意,除了 matmul 和卷积本身,内部使用 matmul 或卷积的函数和 nn 模块也会受到影响。这包括 nn.Linear、nn.Conv*、cdist、tensordot、affine grid 和 grid sample、adaptive log softmax、GRU 和 LSTM。
为了了解精度和速度,请参见下面的示例代码和基准数据(在 A100 上)
a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7277
a = a_full.float()
b = b_full.float()
# Do matmul at TF32 mode.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = a @ b # takes 0.016s on GA100
error = (ab_tf32 - ab_full).abs().max() # 0.1747
relative_error = error / mean # 0.0022
# Do matmul with TF32 disabled.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = a @ b # takes 0.11s on GA100
error = (ab_fp32 - ab_full).abs().max() # 0.0031
relative_error = error / mean # 0.000039
从上面的示例中,我们可以看到,启用 TF32 后,A100 上的速度快了约 7 倍,并且相对于双精度的相对误差大约大了 2 个数量级。请注意,TF32 与单精度速度的精确比率取决于硬件代,因为内存带宽与计算的比率以及 TF32 与 FP32 matmul 吞吐量的比率可能因代或型号而异。如果需要完整的 FP32 精度,用户可以通过以下方式禁用 TF32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
要在 C++ 中关闭 TF32 标志,您可以执行
at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);
有关 TF32 的更多信息,请参阅
FP16 GEMM 中的降精度归约#
(与旨在用于 FP16 累积吞吐量高于 FP32 累积吞吐量的硬件的完整 FP16 累积不同,请参阅 完整 FP16 累积)
fp16 GEMM 可能会使用一些中间的降精度归约(例如,在 fp16 而不是 fp32 中)。这种选择性降精度可以在某些工作负载(特别是那些具有大 k 维的工作负载)和 GPU 架构上实现更高的性能,但代价是数值精度和潜在的溢出。
V100 上的一些示例基准数据
[--------------------------- bench_gemm_transformer --------------------------]
[ m , k , n ] | allow_fp16_reduc=True | allow_fp16_reduc=False
1 threads: --------------------------------------------------------------------
[4096, 4048, 4096] | 1634.6 | 1639.8
[4096, 4056, 4096] | 1670.8 | 1661.9
[4096, 4080, 4096] | 1664.2 | 1658.3
[4096, 4096, 4096] | 1639.4 | 1651.0
[4096, 4104, 4096] | 1677.4 | 1674.9
[4096, 4128, 4096] | 1655.7 | 1646.0
[4096, 4144, 4096] | 1796.8 | 2519.6
[4096, 5096, 4096] | 2094.6 | 3190.0
[4096, 5104, 4096] | 2144.0 | 2663.5
[4096, 5112, 4096] | 2149.1 | 2766.9
[4096, 5120, 4096] | 2142.8 | 2631.0
[4096, 9728, 4096] | 3875.1 | 5779.8
[4096, 16384, 4096] | 6182.9 | 9656.5
(times in microseconds).
如果需要全精度归约,用户可以通过以下方式禁用 fp16 GEMM 中的降精度归约
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
要在 C++ 中切换降精度归约标志,可以执行
at::globalContext().setAllowFP16ReductionCuBLAS(false);
BF16 GEMM 中的降精度归约#
BFloat16 GEMM 也存在类似的标志(如上)。请注意,此开关在 BF16 中默认设置为 True,如果您在工作负载中观察到数值不稳定,您可能希望将其设置为 False。
如果不需要降精度归约,用户可以通过以下方式禁用 bf16 GEMM 中的降精度归约
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
要在 C++ 中切换降精度归约标志,可以执行
at::globalContext().setAllowBF16ReductionCuBLAS(true);
FP16 GEMM 中的完整 FP16 累积#
某些 GPU 在所有 FP16 GEMM 累积都在 FP16 中进行时性能会提高,但代价是数值精度降低和溢出可能性更大。请注意,此设置仅对计算能力为 7.0 (Volta) 或更高版本的 GPU 有效。
可以通过以下方式启用此行为
torch.backends.cuda.matmul.allow_fp16_accumulation = True
要在 C++ 中切换降精度归约标志,可以执行
at::globalContext().setAllowFP16AccumulationCuBLAS(true);
异步执行#
默认情况下,GPU 操作是异步的。当您调用使用 GPU 的函数时,操作会被排队到特定设备,但不一定立即执行。这允许我们并行执行更多计算,包括在 CPU 或其他 GPU 上的操作。
通常,异步计算的效果对调用者是不可见的,因为 (1) 每个设备按排队顺序执行操作,并且 (2) PyTorch 在 CPU 和 GPU 之间或两个 GPU 之间复制数据时自动执行必要的同步。因此,计算将像每个操作都是同步执行一样进行。
您可以通过设置环境变量 CUDA_LAUNCH_BLOCKING=1
来强制同步计算。当 GPU 上发生错误时,这会很方便。(对于异步执行,此类错误直到操作实际执行后才报告,因此堆栈跟踪不会显示请求它的位置。)
异步计算的一个结果是,没有同步的时间测量是不准确的。为了获得精确的测量,应该在测量之前调用 torch.cuda.synchronize()
,或者使用 torch.cuda.Event
记录时间,如下所示
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# Run some things here
end_event.record()
torch.cuda.synchronize() # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
作为例外,一些函数(如 to()
和 copy_()
)允许显式的 non_blocking
参数,这允许调用者在不需要时绕过同步。另一个例外是 CUDA 流,下面将解释。
CUDA 流#
CUDA 流是属于特定设备的线性执行序列。通常不需要显式创建它:默认情况下,每个设备都使用自己的“默认”流。
每个流内部的操作按创建顺序串行化,但来自不同流的操作可以以任何相对顺序并发执行,除非使用显式同步函数(例如 synchronize()
或 wait_stream()
)。例如,以下代码是不正确的
cuda = torch.device('cuda')
s = torch.cuda.Stream() # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
with torch.cuda.stream(s):
# sum() may start execution before normal_() finishes!
B = torch.sum(A)
当“当前流”是默认流时,PyTorch 在数据移动时会自动执行必要的同步,如上所述。但是,当使用非默认流时,用户有责任确保正确同步。此示例的修复版本是
cuda = torch.device('cuda')
s = torch.cuda.Stream() # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
s.wait_stream(torch.cuda.default_stream(cuda)) # NEW!
with torch.cuda.stream(s):
B = torch.sum(A)
A.record_stream(s) # NEW!
有两个新添加。 torch.cuda.Stream.wait_stream()
调用确保 normal_()
执行在我们在侧流上开始运行 sum(A)
之前完成。 torch.Tensor.record_stream()
(有关更多详细信息,请参见)确保我们在 sum(A)
完成之前不会释放 A。您还可以稍后通过 torch.cuda.default_stream(cuda).wait_stream(s)
手动等待流(请注意,立即等待是没有意义的,因为那会阻止流执行与默认流上的其他工作并行运行。)有关何时使用其中一个或另一个的更多详细信息,请参阅 torch.Tensor.record_stream()
的文档。
请注意,即使没有读取依赖项,这种同步也是必要的,例如,如此示例所示
cuda = torch.device('cuda')
s = torch.cuda.Stream() # Create a new stream.
A = torch.empty((100, 100), device=cuda)
s.wait_stream(torch.cuda.default_stream(cuda)) # STILL REQUIRED!
with torch.cuda.stream(s):
A.normal_(0.0, 1.0)
A.record_stream(s)
尽管在 s
上的计算没有读取 A
的内容,并且没有其他使用 A
的情况,仍然需要同步,因为 A
可能对应于由 CUDA 缓存分配器重新分配的内存,其中包含来自旧(已释放)内存的挂起操作。
反向传播的流语义#
每个反向 CUDA 操作都在与其对应的正向操作相同的流上运行。如果您的正向传播在不同流上并行运行独立操作,这有助于反向传播利用相同的并行性。
反向调用与周围操作的流语义与任何其他调用相同。反向传播插入内部同步以确保即使反向操作如上一段所述在多个流上运行也如此。更具体地说,当调用 autograd.backward
、autograd.grad
或 tensor.backward
,并可选地提供 CUDA 张量作为初始梯度(例如,autograd.backward(..., grad_tensors=initial_grads)
、autograd.grad(..., grad_outputs=initial_grads)
或 tensor.backward(..., gradient=initial_grad)
),以下行为
可选地填充初始梯度,
调用反向传播,以及
使用梯度
与任何一组操作具有相同的流语义关系
s = torch.cuda.Stream()
# Safe, grads are used in the same stream context as backward()
with torch.cuda.stream(s):
loss.backward()
use grads
# Unsafe
with torch.cuda.stream(s):
loss.backward()
use grads
# Safe, with synchronization
with torch.cuda.stream(s):
loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads
# Safe, populating initial grad and invoking backward are in the same stream context
with torch.cuda.stream(s):
loss.backward(gradient=torch.ones_like(loss))
# Unsafe, populating initial_grad and invoking backward are in different stream contexts,
# without synchronization
initial_grad = torch.ones_like(loss)
with torch.cuda.stream(s):
loss.backward(gradient=initial_grad)
# Safe, with synchronization
initial_grad = torch.ones_like(loss)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
initial_grad.record_stream(s)
loss.backward(gradient=initial_grad)
向后兼容性说明:在默认流上使用梯度#
在 PyTorch 的早期版本(1.9 及更早版本)中,自动梯度引擎总是将默认流与所有反向操作同步,因此以下模式
with torch.cuda.stream(s):
loss.backward()
use grads
只要 use grads
发生在默认流上,就是安全的。在当前的 PyTorch 中,这种模式不再安全。如果 backward()
和 use grads
处于不同的流上下文中,您必须同步流
with torch.cuda.stream(s):
loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads
即使 use grads
在默认流上也是如此。
内存管理#
PyTorch 使用缓存内存分配器来加速内存分配。这允许快速内存释放而无需设备同步。但是,分配器管理的未使用内存仍将在 nvidia-smi
中显示为已使用。您可以使用 memory_allocated()
和 max_memory_allocated()
监视张量占用的内存,并使用 memory_reserved()
和 max_memory_reserved()
监视缓存分配器管理的总内存量。调用 empty_cache()
会释放 PyTorch 中所有未使用的缓存内存,以便其他 GPU 应用程序可以使用。但是,张量占用的 GPU 内存不会被释放,因此无法增加 PyTorch 可用的 GPU 内存量。
为了更好地理解 CUDA 内存随时间的使用情况,了解 CUDA 内存使用情况 描述了用于捕获和可视化内存使用跟踪的工具。
对于高级用户,我们通过 memory_stats()
提供更全面的内存基准测试。我们还提供通过 memory_snapshot()
捕获内存分配器状态完整快照的功能,这可以帮助您了解代码生成的底层分配模式。
使用 PYTORCH_CUDA_ALLOC_CONF
优化内存使用#
使用缓存分配器可能会干扰内存检查工具,例如 cuda-memcheck
。要使用 cuda-memcheck
调试内存错误,请在您的环境中设置 PYTORCH_NO_CUDA_MEMORY_CACHING=1
以禁用缓存。
缓存分配器的行为可以通过环境变量 PYTORCH_CUDA_ALLOC_CONF
控制。格式为 PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...
可用选项
backend
允许选择底层分配器实现。当前有效选项是native
(使用 PyTorch 的原生实现)和cudaMallocAsync
(使用 CUDA 内置的异步分配器)。cudaMallocAsync
需要 CUDA 11.4 或更高版本。默认值是native
。backend
适用于进程使用的所有设备,不能按设备指定。max_split_size_mb
阻止原生分配器分割大于此大小(MB)的块。这可以减少碎片化,并可能允许一些临界工作负载在不耗尽内存的情况下完成。性能成本从“零”到“显著”不等,具体取决于分配模式。默认值是无限制,即所有块都可以分割。memory_stats()
和memory_summary()
方法对于调优很有用。此选项应作为工作负载因“内存不足”而中止且显示大量非活动分割块时的最后手段。max_split_size_mb
仅在使用backend:native
时有意义。在使用backend:cudaMallocAsync
时,max_split_size_mb
被忽略。roundup_power2_divisions
有助于将请求的分配大小四舍五入到最近的 2 次方分割,并更好地利用块。在原生 CUDACachingAllocator 中,大小以 512 的块大小的倍数四舍五入,因此这对于较小的大小很好。但是,对于大型邻近分配,这可能效率低下,因为每个分配都将进入不同大小的块,并且这些块的重用最小化。这可能会创建大量未使用的块并浪费 GPU 内存容量。此选项启用将分配大小四舍五入到最近的 2 次方分割。例如,如果我们需要将大小 1200 四舍五入,并且分割数为 4,则大小 1200 介于 1024 和 2048 之间,如果我们在它们之间进行 4 次分割,则值为 1024、1280、1536 和 1792。因此,分配大小 1200 将四舍五入到 1280 作为最接近的 2 次方分割上限。指定单个值以应用于所有分配大小,或指定键值对数组以单独设置每个 2 次方间隔的 2 次方分割。例如,要为所有小于 256MB 的分配设置 1 次分割,为 256MB 到 512MB 之间的分配设置 2 次分割,为 512MB 到 1GB 之间的分配设置 4 次分割,以及为任何更大的分配设置 8 次分割,请将此选项值设置为:[256:1,512:2,1024:4,>:8]。roundup_power2_divisions
仅在使用backend:native
时有意义。在使用backend:cudaMallocAsync
时,roundup_power2_divisions
被忽略。max_non_split_rounding_mb
将允许非分割块更好地重用,例如,一个 1024MB 的缓存块可以用于 512MB 的分配请求。在默认情况下,我们只允许非分割块进行最大 20MB 的舍入,因此 512MB 的块只能使用 512-532 MB 大小的块。如果我们将此选项的值设置为 1024,它将允许 512-1536 MB 大小的块用于 512MB 的块,这增加了较大块的重用。这还将有助于减少避免昂贵的 cudaMalloc 调用的停滞。
garbage_collection_threshold
有助于主动回收未使用的 GPU 内存,以避免触发昂贵的同步和回收所有操作(release_cached_blocks),这可能不利于对延迟敏感的 GPU 应用程序(例如服务器)。设置此阈值(例如 0.8)后,如果 GPU 内存容量使用量超过阈值(即分配给 GPU 应用程序的总内存的 80%),分配器将开始回收 GPU 内存块。该算法优先释放旧的未使用的块,以避免释放正在积极重用的块。阈值应大于 0.0 且小于 1.0。默认值为 1.0。garbage_collection_threshold
仅在使用backend:native
时有意义。在使用backend:cudaMallocAsync
时,garbage_collection_threshold
被忽略。expandable_segments
(实验性,默认值:False)如果设置为 True,此设置指示分配器创建 CUDA 分配,这些分配以后可以扩展,以更好地处理作业频繁更改分配大小的情况,例如更改批量大小。通常对于大型(>2MB)分配,分配器调用 cudaMalloc 获取与用户请求大小相同的分配。将来,如果这些分配空闲,它们的某些部分可以重用于其他请求。当程序发出许多完全相同大小或大小为该大小的偶数倍的请求时,这很有效。许多深度学习模型都遵循此行为。但是,一个常见的例外是当批量大小从一个迭代到下一个迭代略有变化时,例如在批推理中。当程序最初以批量大小 N 运行时,它将进行适合该大小的分配。如果将来,它以大小 N - 1 运行,则现有分配仍将足够大。但是,如果它以大小 N + 1 运行,那么它将不得不进行略大的新分配。并非所有张量都相同大小。有些可能是 (N + 1)*A,而另一些是 (N + 1)*A*B,其中 A 和 B 是模型中的一些非批量维度。因为分配器在现有分配足够大时重用它们,所以一些 (N + 1)*A 分配实际上将适合已存在的 N*B*A 段,尽管不完美。随着模型的运行,它将部分填充所有这些段,在这些段的末尾留下不可用的空闲内存片。分配器在某个时候需要 cudaMalloc 一个新的 (N + 1)*A*B 段。如果内存不足,现在无法恢复现有段末尾的空闲内存片。对于 50 层以上的深层模型,这种模式可能会重复 50 次以上,产生许多碎片。expandable_segments 允许分配器最初创建一个段,然后在需要更多内存时扩展其大小。它不是为每个分配创建一个段,而是尝试创建一个(每流)按需增长的段。现在,当 N + 1 的情况运行时,分配将整齐地平铺到一个大段中,直到它填满。然后请求更多内存并附加到段的末尾。此过程不会创建那么多不可用内存的碎片,因此更有可能成功找到此内存。
pinned_use_cuda_host_register 选项是一个布尔标志,用于确定是使用 CUDA API 的 cudaHostRegister 函数分配固定内存,还是使用默认的 cudaHostAlloc。当设置为 True 时,内存使用常规 malloc 分配,然后页面在调用 cudaHostRegister 之前映射到内存。这种预映射页面有助于减少 cudaHostRegister 执行期间的锁定时间。
pinned_num_register_threads 选项仅在 pinned_use_cuda_host_register 设置为 True 时有效。默认情况下,一个线程用于映射页面。此选项允许使用更多线程并行化页面映射操作,以减少固定内存的整体分配时间。根据基准测试结果,此选项的良好值为 8。
pinned_use_background_threads 选项是一个布尔标志,用于启用后台线程处理事件。这避免了与快速分配路径中的事件查询/处理相关的任何慢速路径。此功能默认禁用。
注意
CUDA 内存管理 API 报告的一些统计信息特定于 backend:native
,并且对于 backend:cudaMallocAsync
没有意义。有关详细信息,请参阅每个函数的文档字符串。
使用 CUDA 的自定义内存分配器#
可以将分配器定义为 C/C++ 中的简单函数,并将其编译为共享库,下面的代码展示了一个只跟踪所有内存操作的基本分配器。
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
// Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
void *ptr;
cudaMalloc(&ptr, size);
std::cout<<"alloc "<<ptr<<size<<std::endl;
return ptr;
}
void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
std::cout<<"free "<<ptr<< " "<<stream<<std::endl;
cudaFree(ptr);
}
}
这可以通过 torch.cuda.memory.CUDAPluggableAllocator
在 Python 中使用。用户负责提供 .so 文件的路径以及与上述签名匹配的 alloc/free 函数的名称。
import torch
# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
'alloc.so', 'my_malloc', 'my_free')
# Swap the current allocator
torch.cuda.memory.change_current_allocator(new_alloc)
# This will allocate memory in the device using the new allocator
b = torch.zeros(10, device='cuda')
import torch
# Do an initial memory allocator
b = torch.zeros(10, device='cuda')
# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
'alloc.so', 'my_malloc', 'my_free')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(new_alloc)
在同一程序中混合使用不同的 CUDA 系统分配器#
根据您的用例,change_current_allocator()
可能不是您想要使用的,因为它会为整个程序交换 CUDA 分配器(类似于 PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync
)。例如,如果交换的分配器没有缓存机制,您将失去 PyTorch 的 CUDACachingAllocator 的所有好处。相反,您可以使用 torch.cuda.MemPool
选择性地标记 PyTorch 代码区域以使用自定义分配器。这将允许您在同一个 PyTorch 程序中使用多个 CUDA 系统分配器,并获得 CUDACachingAllocator 的大部分好处(例如缓存)。使用 torch.cuda.MemPool
,您可以利用自定义分配器,从而启用多项功能,例如
使用
ncclMemAlloc
分配器为 all-reduce 分配输出缓冲区可以启用 NVLink 交换归约 (NVLS)。这可以减少重叠计算和通信内核在 GPU 资源(SM 和复制引擎)上的竞争,尤其是在张量并行工作负载中。对于基于 Grace CPU 的系统,使用
cuMemCreate
并指定CU_MEM_LOCATION_TYPE_HOST_NUMA
为 all-gather 分配主机输出缓冲区可以启用基于扩展 GPU 内存 (EGM) 的从源 GPU 到目标 CPU 的内存传输。这加速了 all-gather,因为传输通过 NVLinks 进行,否则将通过带宽受限的网络接口卡 (NIC) 链接进行。这种加速的 all-gather 反过来可以加快模型检查点的速度。如果您正在构建模型并且最初不想考虑内存密集型模块(例如嵌入表)的最佳内存放置,或者您有一个对性能不敏感且不适合 GPU 的模块,那么您可以首先使用
cudaMallocManaged
并指定首选 CPU 位置来分配该模块,然后让您的模型正常工作。
注意
虽然 cudaMallocManaged
提供了使用 CUDA 统一虚拟内存 (UVM) 的便捷自动内存管理,但它不建议用于 DL 工作负载。对于适合 GPU 内存的 DL 工作负载,显式放置始终优于 UVM,因为没有页面错误且访问模式保持可预测。当 GPU 内存饱和时,UVM 必须执行代价高昂的双重传输,即在引入新页面之前将页面驱逐到 CPU。
下面的代码显示了 ncclMemAlloc
被封装在 torch.cuda.memory.CUDAPluggableAllocator
中。
import os
import torch
import torch.distributed as dist
from torch.cuda.memory import CUDAPluggableAllocator
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils import cpp_extension
# create allocator
nccl_allocator_source = """
#include <nccl.h>
#include <iostream>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
std::cout << "Using ncclMemAlloc" << std::endl;
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
std::cout << "Using ncclMemFree" << std::endl;
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
nccl_allocator_libname = "nccl_allocator"
nccl_allocator = torch.utils.cpp_extension.load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=True,
is_python_module=False,
build_directory="./",
)
allocator = CUDAPluggableAllocator(
f"./{nccl_allocator_libname}.so", "nccl_alloc_plug", "nccl_free_plug"
).allocator()
# setup distributed
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
default_pg = _get_default_group()
backend = default_pg._get_backend(device)
# Note: for convenience, ProcessGroupNCCL backend provides
# the ncclMemAlloc allocator as backend.mem_allocator
allocator = backend.mem_allocator
现在可以通过将此分配器传递给 torch.cuda.MemPool
来定义新的内存池
pool = torch.cuda.MemPool(allocator)
然后可以将池与 torch.cuda.use_mem_pool
上下文管理器一起使用,以将张量分配到该池中
with torch.cuda.use_mem_pool(pool):
# tensor gets allocated with ncclMemAlloc passed in the pool
tensor = torch.arange(1024 * 1024 * 2, device=device)
print(f"tensor ptr on rank {rank} is {hex(tensor.data_ptr())}")
# register user buffers using ncclCommRegister (called under the hood)
backend.register_mem_pool(pool)
# Collective uses Zero Copy NVLS
dist.all_reduce(tensor[0:4])
torch.cuda.synchronize()
print(tensor[0:4])
请注意上面示例中 register_mem_pool
的用法。这是 NVLS 归约的额外步骤,其中用户缓冲区需要注册到 NCCL。用户可以使用类似的 deregister_mem_pool
调用来取消注册缓冲区。
要回收内存,用户首先需要确保没有任何东西正在使用该池。当没有张量持有对该池的引用时,empty_cache()
将在池删除时内部调用,从而将所有内存返回给系统。
del tensor, del pool
用户可以在 MemPool 创建期间可选地指定 use_on_oom
布尔值(默认为 False)。如果为 true,则 CUDACachingAllocator 将能够将此池中的内存作为最后的手段,而不是 OOM。
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
a = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a
# at the memory limit, this will succeed by using pool's memory in order to avoid the oom
b = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
以下 torch.cuda.MemPool.use_count()
和 torch.cuda.MemPool.snapshot()
API 可用于调试目的
pool = torch.cuda.MemPool(allocator)
# pool's use count should be 1 at this point as MemPool object
# holds a reference
assert pool.use_count() == 1
nelem_1mb = 1024 * 1024 // 4
with torch.cuda.use_mem_pool(pool):
out_0 = torch.randn(nelem_1mb, device="cuda")
# pool's use count should be 2 at this point as use_mem_pool
# holds a reference
assert pool.use_count() == 2
# pool's use count should be back to 1 at this point as use_mem_pool
# released its reference
assert pool.use_count() == 1
with torch.cuda.use_mem_pool(pool):
# pool should have 1 segment since we made a small allocation (1 MB)
# above and so the CUDACachingAllocator packed it into a 2 MB buffer
assert len(pool.snapshot()) == 1
out_1 = torch.randn(nelem_1mb, device="cuda")
# pool should still have 1 segment since we made another small allocation
# (1 MB) that got packed into the existing 2 MB buffer
assert len(pool.snapshot()) == 1
out_2 = torch.randn(nelem_1mb, device="cuda")
# pool now should have 2 segments since the CUDACachingAllocator had
# to make a new 2 MB buffer to accommodate out_2
assert len(pool.snapshot()) == 2
注意
torch.cuda.MemPool
持有对池的引用。当您使用torch.cuda.use_mem_pool
上下文管理器时,它也将获取对池的另一个引用。退出上下文管理器时,它将释放其引用。在那之后,理想情况下,只有张量才持有对池的引用。一旦张量释放了它们的引用,池的使用计数将为 1,反映出只有torch.cuda.MemPool
对象持有引用。只有在那时,当调用池的析构函数时,池持有的内存才能通过del
返回给系统。torch.cuda.MemPool
目前不支持 CUDACachingAllocator 的expandable_segments
模式。NCCL 对缓冲区有一些特定要求,以使其与 NVLS 归约兼容。这些要求在动态工作负载中可能会被打破,例如,由 CUDACachingAllocator 发送给 NCCL 的缓冲区可能被分割,因此未正确对齐。在这些情况下,NCCL 可以使用回退算法而不是 NVLS。
像
ncclMemAlloc
这样的分配器可能会由于对齐要求(CU_MULTICAST_GRANULARITY_RECOMMENDED
、CU_MULTICAST_GRANULARITY_MINIMUM
)而使用比请求更多的内存,这可能导致您的工作负载内存不足。
cuBLAS 工作区#
对于 cuBLAS 句柄和 CUDA 流的每个组合,如果该句柄和流组合执行需要工作区的 cuBLAS 内核,则将分配一个 cuBLAS 工作区。为了避免重复分配工作区,这些工作区不会被释放,除非调用 torch._C._cuda_clearCublasWorkspaces()
。每个分配的工作区大小可以通过环境变量 CUBLAS_WORKSPACE_CONFIG
指定,格式为 :[SIZE]:[COUNT]
。例如,每个分配的默认工作区大小是 CUBLAS_WORKSPACE_CONFIG=:4096:2:16:8
,它指定总大小为 2 * 4096 + 8 * 16 KiB
。要强制 cuBLAS 避免使用工作区,请设置 CUBLAS_WORKSPACE_CONFIG=:0:0
。
cuFFT 计划缓存#
对于每个 CUDA 设备,使用 cuFFT 计划的 LRU 缓存来加速在具有相同配置的相同几何形状的 CUDA 张量上重复运行 FFT 方法(例如,torch.fft.fft()
)。由于某些 cuFFT 计划可能会分配 GPU 内存,因此这些缓存具有最大容量。
您可以使用以下 API 控制和查询当前设备的缓存属性
torch.backends.cuda.cufft_plan_cache.max_size
给出缓存的容量(CUDA 10 及更高版本默认为 4096,旧版 CUDA 默认为 1023)。直接设置此值会修改容量。torch.backends.cuda.cufft_plan_cache.size
给出缓存中当前存在的计划数量。torch.backends.cuda.cufft_plan_cache.clear()
清除缓存。
要控制和查询非默认设备的计划缓存,您可以使用 torch.device
对象或设备索引来索引 torch.backends.cuda.cufft_plan_cache
对象,并访问上述属性之一。例如,要设置设备 1
的缓存容量,可以写入 torch.backends.cuda.cufft_plan_cache[1].max_size = 10
。
即时编译#
PyTorch 在对 CUDA 张量执行某些操作(如 torch.special.zeta)时会进行即时编译。此编译可能非常耗时(根据您的硬件和软件,可能长达几秒钟),并且单个操作可能会多次发生,因为许多 PyTorch 操作实际上会从各种内核中进行选择,每个内核都必须编译一次,具体取决于它们的输入。此编译每个进程发生一次,或者如果使用内核缓存则仅发生一次。
默认情况下,如果定义了 XDG_CACHE_HOME,PyTorch 会在 $XDG_CACHE_HOME/torch/kernels 中创建内核缓存;如果未定义,则在 $HOME/.cache/torch/kernels 中创建(Windows 除外,该系统尚未支持内核缓存)。缓存行为可以通过两个环境变量直接控制。如果 USE_PYTORCH_KERNEL_CACHE 设置为 0,则不会使用缓存;如果设置了 PYTORCH_KERNEL_CACHE_PATH,则该路径将用作内核缓存,而不是默认位置。
最佳实践#
设备无关代码#
由于 PyTorch 的结构,您可能需要显式编写设备无关(CPU 或 GPU)代码;例如,将新张量创建为循环神经网络的初始隐藏状态。
第一步是确定是否应该使用 GPU。一个常见的模式是使用 Python 的 argparse
模块读取用户参数,并设置一个标志,可用于禁用 CUDA,结合 is_available()
。在下文中,args.device
导致一个 torch.device
对象,可用于将张量移动到 CPU 或 CUDA。
import argparse
import torch
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--disable-cuda', action='store_true',
help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device('cuda')
else:
args.device = torch.device('cpu')
注意
在评估给定环境中 CUDA 的可用性(is_available()
)时,PyTorch 的默认行为是调用 CUDA 运行时 API 方法 cudaGetDeviceCount。由于此调用反过来会初始化 CUDA 驱动程序 API(通过 cuInit)(如果尚未初始化),因此运行过 is_available()
的进程后续分叉将失败并出现 CUDA 初始化错误。
您可以在导入执行 is_available()
的 PyTorch 模块之前(或在直接执行它之前)在环境中设置 PYTORCH_NVML_BASED_CUDA_CHECK=1
,以指示 is_available()
尝试进行基于 NVML 的评估(nvmlDeviceGetCount_v2)。如果基于 NVML 的评估成功(即 NVML 发现/初始化未失败),则 is_available()
调用将不会毒化后续的分叉。
如果 NVML 发现/初始化失败,is_available()
将回退到标准的 CUDA 运行时 API 评估,并且上述分叉限制将适用。
请注意,上述基于 NVML 的 CUDA 可用性评估提供了比默认 CUDA 运行时 API 方法(需要 CUDA 初始化成功)更弱的保证。在某些情况下,基于 NVML 的检查可能成功,而后续的 CUDA 初始化失败。
现在我们有了 args.device
,我们可以用它在所需设备上创建张量。
x = torch.empty((8, 42), device=args.device)
net = Network().to(device=args.device)
这可以在许多情况下用于生成与设备无关的代码。以下是使用数据加载器时的示例
cuda0 = torch.device('cuda:0') # CUDA GPU 0
for i, x in enumerate(train_loader):
x = x.to(cuda0)
在系统上使用多个 GPU 时,可以使用 CUDA_VISIBLE_DEVICES
环境变量来管理哪些 GPU 可用于 PyTorch。如上所述,要手动控制在哪个 GPU 上创建张量,最佳实践是使用 torch.cuda.device
上下文管理器。
print("Outside device is 0") # On device 0 (default in most scenarios)
with torch.cuda.device(1):
print("Inside device is 1") # On device 1
print("Outside device is still 0") # On device 0
如果您有一个张量,并且希望在同一设备上创建相同类型的新张量,那么可以使用 torch.Tensor.new_*
方法(参见 torch.Tensor
)。虽然前面提到的 torch.*
工厂函数(创建操作)取决于当前的 GPU 上下文和您传入的属性参数,但 torch.Tensor.new_*
方法会保留张量的设备和其他属性。
这是在创建需要在前向传播期间内部创建新张量的模块时推荐的做法。
cuda = torch.device('cuda')
x_cpu = torch.empty(2)
x_gpu = torch.empty(2, device=cuda)
x_cpu_long = torch.empty(2, dtype=torch.int64)
y_cpu = x_cpu.new_full([3, 2], fill_value=0.3)
print(y_cpu)
tensor([[ 0.3000, 0.3000],
[ 0.3000, 0.3000],
[ 0.3000, 0.3000]])
y_gpu = x_gpu.new_full([3, 2], fill_value=-5)
print(y_gpu)
tensor([[-5.0000, -5.0000],
[-5.0000, -5.0000],
[-5.0000, -5.0000]], device='cuda:0')
y_cpu_long = x_cpu_long.new_tensor([[1, 2, 3]])
print(y_cpu_long)
tensor([[ 1, 2, 3]])
如果您想创建一个与另一个张量类型和大小相同,并填充为全 1 或全 0 的张量,则提供了 ones_like()
或 zeros_like()
作为方便的辅助函数(它们也保留张量的 torch.device
和 torch.dtype
)。
x_cpu = torch.empty(2, 3)
x_gpu = torch.empty(2, 3)
y_cpu = torch.ones_like(x_cpu)
y_gpu = torch.zeros_like(x_gpu)
使用固定内存缓冲区#
警告
这是一个高级提示。如果您过度使用固定内存,在 RAM 不足时可能会导致严重问题,并且您应该意识到固定通常是一个昂贵的操作。
当从固定(页面锁定)内存发起时,主机到 GPU 的复制速度要快得多。CPU 张量和存储公开了一个 pin_memory()
方法,该方法返回对象的副本,数据放在固定区域中。
此外,一旦您固定了张量或存储,您就可以使用异步 GPU 复制。只需将额外的 non_blocking=True
参数传递给 to()
或 cuda()
调用。这可用于将数据传输与计算重叠。
您可以通过将 pin_memory=True
传递给 DataLoader
的构造函数,使其返回放置在固定内存中的批次。
使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel#
大多数涉及批量输入和多个 GPU 的用例应默认使用 DistributedDataParallel
来利用多个 GPU。
将 CUDA 模型与 multiprocessing
一起使用存在重大注意事项;除非严格满足数据处理要求,否则您的程序很可能出现不正确或未定义的行为。
建议使用 DistributedDataParallel
而不是 DataParallel
进行多 GPU 训练,即使只有一个节点。
DistributedDataParallel
和 DataParallel
的区别在于:DistributedDataParallel
使用多进程,为每个 GPU 创建一个进程,而 DataParallel
使用多线程。通过使用多进程,每个 GPU 都有其专用进程,这避免了 Python 解释器 GIL 引起的性能开销。
如果您使用 DistributedDataParallel
,您可以使用 torch.distributed.launch 实用程序来启动您的程序,请参阅 启动实用程序。
CUDA 图#
CUDA 图是对 CUDA 流及其依赖流执行的工作(主要是内核及其参数)的记录。有关基本原理和底层 CUDA API 的详细信息,请参阅 CUDA 图入门 和 CUDA C 编程指南的 图部分。
PyTorch 支持使用 流捕获 来构建 CUDA 图,这会使 CUDA 流进入捕获模式。提交到捕获流的 CUDA 工作实际上不会在 GPU 上运行。相反,工作会被记录在图中。
捕获后,可以启动图以运行 GPU 工作所需的次数。每次重放都会以相同的参数运行相同的内核。对于指针参数,这意味着使用相同的内存地址。通过在每次重放之前用新数据填充输入内存(例如,来自新批次),您可以在新数据上重新运行相同的工作。
为什么使用 CUDA 图?#
重放图牺牲了典型急切执行的动态灵活性,以换取**大大减少的 CPU 开销**。图的参数和内核是固定的,因此图重放跳过了所有层的参数设置和内核调度,包括 Python、C++ 和 CUDA 驱动程序开销。在底层,重放通过一次调用 cudaGraphLaunch 将整个图的工作提交给 GPU。重放中的内核在 GPU 上执行速度也略快,但消除 CPU 开销是主要好处。
如果您的网络全部或部分是图安全的(通常这意味着静态形状和静态控制流,但请参见其他限制),并且您怀疑其运行时至少在某种程度上受到 CPU 限制,那么您应该尝试使用 CUDA 图。
PyTorch API#
警告
此 API 处于 Beta 版,未来版本中可能会更改。
PyTorch 通过原始的 torch.cuda.CUDAGraph
类和两个方便的包装器 torch.cuda.graph
和 torch.cuda.make_graphed_callables
公开图。
torch.cuda.graph
是一个简单、通用的上下文管理器,可在其上下文中捕获 CUDA 工作。在捕获之前,通过运行一些急切的迭代来预热要捕获的工作负载。预热必须在侧流上进行。因为图在每次重放中都从相同的内存地址读取和写入,所以您必须保持对在捕获期间保存输入和输出数据的张量的长期引用。要在新输入数据上运行图,请将新数据复制到捕获的输入张量,重放图,然后从捕获的输出张量中读取新输出。示例
g = torch.cuda.CUDAGraph()
# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda")
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
static_output = static_input * 2
# Fills the graph's input memory with new data to compute on
static_input.copy_(torch.full((5,), 3, device="cuda"))
g.replay()
# static_output holds the results
print(static_output) # full of 3 * 2 = 6
# Fills the graph's input memory with more data to compute on
static_input.copy_(torch.full((5,), 4, device="cuda"))
g.replay()
print(static_output) # full of 4 * 2 = 8
有关实际和高级模式,请参阅全网络捕获、与 torch.cuda.amp 的用法 和与多个流的用法。
make_graphed_callables
更为复杂。make_graphed_callables
接受 Python 函数和 torch.nn.Module
。对于每个传递的函数或模块,它会创建前向和后向工作的独立图。请参阅部分网络捕获。
约束#
如果一组操作不违反以下任何约束,则它就是可捕获的。
约束适用于 torch.cuda.graph
上下文中的所有工作,以及您传递给 torch.cuda.make_graphed_callables()
的任何可调用函数的前向和后向传播中的所有工作。
违反以下任何一项都可能导致运行时错误
捕获必须在非默认流上进行。(这仅当您使用原始
CUDAGraph.capture_begin
和CUDAGraph.capture_end
调用时才需要考虑。graph
和make_graphed_callables()
会为您设置一个侧流。)禁止同步 CPU 与 GPU 的操作(例如,
.item()
调用)。允许 CUDA RNG 操作,当在图中使用多个
torch.Generator
实例时,它们必须在图捕获之前使用CUDAGraph.register_generator_state
进行注册。避免在捕获期间使用Generator.get_state
和Generator.set_state
;相反,请使用Generator.graphsafe_set_state
和Generator.graphsafe_get_state
在图上下文中安全地管理生成器状态。这确保了 CUDA 图中正确的 RNG 操作和生成器管理。
违反以下任何一项都可能导致静默数值错误或未定义的行为
在一个进程内,一次只能进行一次捕获。
在捕获进行时,此进程(在任何线程上)不得运行非捕获的 CUDA 工作。
CPU 工作未被捕获。如果捕获的操作包含 CPU 工作,则该工作将在重放期间被省略。
每次重放都从相同的(虚拟)内存地址读取和写入。
禁止动态控制流(基于 CPU 或 GPU 数据)。
禁止动态形状。图假定捕获的操作序列中的每个张量在每次重放中都具有相同的大小和布局。
允许在一次捕获中使用多个流,但存在限制。
非约束#
一旦捕获,图可以在任何流上重放。
全网络捕获#
如果您的整个网络都是可捕获的,您可以捕获并重放整个迭代
N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
torch.nn.Dropout(p=0.2),
torch.nn.Linear(H, D_out),
torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')
# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for i in range(3):
optimizer.zero_grad(set_to_none=True)
y_pred = model(static_input)
loss = loss_fn(y_pred, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
static_y_pred = model(static_input)
static_loss = loss_fn(static_y_pred, static_target)
static_loss.backward()
optimizer.step()
real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]
for data, target in zip(real_inputs, real_targets):
# Fills the graph's input memory with new data to compute on
static_input.copy_(data)
static_target.copy_(target)
# replay() includes forward, backward, and step.
# You don't even need to call optimizer.zero_grad() between iterations
# because the captured backward refills static .grad tensors in place.
g.replay()
# Params have been updated. static_y_pred, static_loss, and .grad
# attributes hold values from computing on this iteration's data.
部分网络捕获#
如果您的部分网络不适合捕获(例如,由于动态控制流、动态形状、CPU 同步或必要的 CPU 端逻辑),您可以急切地运行不安全的部分,并使用 torch.cuda.make_graphed_callables()
仅捕获安全部分并将其图形化。
默认情况下,make_graphed_callables()
返回的可调用对象是 autograd 感知的,可以在训练循环中直接替换您传入的函数或 nn.Module
。
make_graphed_callables()
内部创建 CUDAGraph
对象,运行预热迭代,并根据需要维护静态输入和输出。因此(与 torch.cuda.graph
不同),您无需手动处理这些。
在以下示例中,数据相关的动态控制流意味着网络无法端到端捕获,但 make_graphed_callables()
允许我们无论如何都能捕获并运行图安全的部分作为图。
N, D_in, H, D_out = 640, 4096, 2048, 1024
module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()
module3 = torch.nn.Linear(H, D_out).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(chain(module1.parameters(),
module2.parameters(),
module3.parameters()),
lr=0.1)
# Sample inputs used for capture
# requires_grad state of sample inputs must match
# requires_grad state of real inputs each callable will see.
x = torch.randn(N, D_in, device='cuda')
h = torch.randn(N, H, device='cuda', requires_grad=True)
module1 = torch.cuda.make_graphed_callables(module1, (x,))
module2 = torch.cuda.make_graphed_callables(module2, (h,))
module3 = torch.cuda.make_graphed_callables(module3, (h,))
real_inputs = [torch.rand_like(x) for _ in range(10)]
real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]
for data, target in zip(real_inputs, real_targets):
optimizer.zero_grad(set_to_none=True)
tmp = module1(data) # forward ops run as a graph
if tmp.sum().item() > 0:
tmp = module2(tmp) # forward ops run as a graph
else:
tmp = module3(tmp) # forward ops run as a graph
loss = loss_fn(tmp, target)
# module2's or module3's (whichever was chosen) backward ops,
# as well as module1's backward ops, run as graphs
loss.backward()
optimizer.step()
与 torch.cuda.amp 的用法#
对于典型的优化器,GradScaler.step
会使 CPU 与 GPU 同步,这在捕获期间是被禁止的。为避免错误,可以使用 部分网络捕获,或者(如果前向、损失和反向传播是捕获安全的)捕获前向、损失和反向传播,但不捕获优化器步骤。
# warmup
# In a real setting, use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for i in range(3):
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
y_pred = model(static_input)
loss = loss_fn(y_pred, static_target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.current_stream().wait_stream(s)
# capture
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with torch.cuda.amp.autocast():
static_y_pred = model(static_input)
static_loss = loss_fn(static_y_pred, static_target)
scaler.scale(static_loss).backward()
# don't capture scaler.step(optimizer) or scaler.update()
real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]
for data, target in zip(real_inputs, real_targets):
static_input.copy_(data)
static_target.copy_(target)
g.replay()
# Runs scaler.step and scaler.update eagerly
scaler.step(optimizer)
scaler.update()
与多个流的用法#
捕获模式会自动传播到与捕获流同步的任何流。在捕获中,您可以通过向不同流发出调用来暴露并行性,但整体流依赖 DAG 必须在捕获开始后从初始捕获流分支出去,并在捕获结束前重新加入初始流。
with torch.cuda.graph(g):
# at context manager entrance, torch.cuda.current_stream()
# is the initial capturing stream
# INCORRECT (does not branch out from or rejoin initial stream)
with torch.cuda.stream(s):
cuda_work()
# CORRECT:
# branches out from initial stream
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
cuda_work()
# rejoins initial stream before capture ends
torch.cuda.current_stream().wait_stream(s)
注意
为了避免高级用户在 nsight systems 或 nvprof 中查看重放时感到困惑:与急切执行不同,图将捕获中的非平凡流 DAG 解释为提示,而不是命令。在重放期间,图可能会将独立的运算重新组织到不同的流上,或者以不同的顺序将它们排入队列(同时尊重您原始 DAG 的总体依赖关系)。
与 DistributedDataParallel 的用法#
NCCL < 2.9.6#
早于 2.9.6 的 NCCL 版本不允许捕获集体通信。您必须使用 部分网络捕获,这将 allreduce 操作推迟到图形化反向传播部分之外进行。
在用 DDP 包装网络之前,在可图形化网络部分上调用 make_graphed_callables()
。
NCCL >= 2.9.6#
NCCL 2.9.6 或更高版本允许在图中进行集体通信。捕获 整个反向传播 的方法是一个可行的选择,但需要三个设置步骤。
禁用 DDP 的内部异步错误处理。
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" torch.distributed.init_process_group(...)
在完全反向捕获之前,DDP 必须在侧流上下文中构建。
with torch.cuda.stream(s): model = DistributedDataParallel(model)
您的预热必须在捕获前至少运行 11 次启用 DDP 的急切迭代。
图内存管理#
捕获的图每次重播时都在相同的虚拟地址上操作。如果 PyTorch 释放了内存,后续的重播可能会发生非法内存访问。如果 PyTorch 将内存重新分配给新的张量,重播可能会损坏这些张量所看到的值。因此,图使用的虚拟地址必须在重播期间为图保留。PyTorch 缓存分配器通过检测何时正在进行捕获并从图私有内存池满足捕获的分配来实现此目的。私有池在其 CUDAGraph
对象和所有在捕获期间创建的张量超出范围之前保持活动状态。
私有池是自动维护的。默认情况下,分配器为每个捕获创建一个单独的私有池。如果您捕获多个图,这种保守的方法可确保图重放永远不会损坏彼此的值,但有时会不必要地浪费内存。
跨捕获共享内存#
为了节省私有池中存储的内存,torch.cuda.graph
和 torch.cuda.make_graphed_callables()
可选地允许不同的捕获共享相同的私有池。如果您知道一组图将始终按照它们捕获的相同顺序重播,并且永不并发重播,则它们共享一个私有池是安全的。
torch.cuda.graph
的 pool
参数是使用特定私有池的提示,可用于跨图共享内存,如所示。
g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()
# (create static inputs for g1 and g2, run warmups of their workloads...)
# Captures g1
with torch.cuda.graph(g1):
static_out_1 = g1_workload(static_in_1)
# Captures g2, hinting that g2 may share a memory pool with g1
with torch.cuda.graph(g2, pool=g1.pool()):
static_out_2 = g2_workload(static_in_2)
static_in_1.copy_(real_data_1)
static_in_2.copy_(real_data_2)
g1.replay()
g2.replay()
对于 torch.cuda.make_graphed_callables()
,如果您想将多个可调用对象图形化,并且知道它们将始终以相同顺序运行(且永不并发),则以它们在实际工作负载中运行的相同顺序将它们作为元组传递,make_graphed_callables()
将使用共享私有池捕获它们的图。
如果在实际工作负载中,您的可调用对象将以偶尔变化的顺序运行,或者它们将并发运行,则不允许将它们作为元组传递给 make_graphed_callables()
的单个调用。相反,您必须为每个可调用对象单独调用 make_graphed_callables()
。