注意
转到末尾下载完整的示例代码。
如何通过将优化器步骤融入反向传播来节省内存#
创建于:2023年10月02日 | 最后更新:2024年01月16日 | 最后验证:2024年11月05日
您好!本教程旨在展示一种减少训练循环内存占用的方法,即减少*梯度*所占用的内存。假设您有一个模型,并且您正在寻找优化内存的方法以避免内存不足
(OOM)错误,或者仅仅是为了更好地利用您的 GPU。那么,您_可能_很幸运(如果梯度占用了您内存的一部分,并且您不需要进行梯度累积)。我们将探讨以下内容:
在训练或微调循环中什么会占用内存,
如何捕获和可视化内存快照以确定瓶颈,
新的
Tensor.register_post_accumulate_grad_hook(hook)
API,以及最后,如何用 10 行代码将所有内容整合在一起以实现内存节省。
要运行本教程,您需要:
PyTorch 2.1.0 或更新版本,并安装
torchvision
如果您想在本地运行内存可视化,需要 1 个 CUDA GPU。否则,此技术在任何设备上都能获得类似的好处。
让我们从导入所需的模块和模型开始。我们将使用 torchvision 中的一个视觉变换器模型,但您也可以随时用自己的模型替换。我们还将使用 torch.optim.Adam
作为我们的优化器,但同样,您也可以随时用自己的优化器替换。
import torch
from torchvision import models
from pickle import dump
model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
Downloading: "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vit_l_16-852ce7e3.pth
0%| | 0.00/1.13G [00:00<?, ?B/s]
1%|▏ | 16.5M/1.13G [00:00<00:16, 72.2MB/s]
3%|▎ | 32.9M/1.13G [00:00<00:16, 69.9MB/s]
4%|▍ | 49.2M/1.13G [00:00<00:14, 77.8MB/s]
6%|▌ | 64.9M/1.13G [00:00<00:13, 87.6MB/s]
6%|▋ | 73.6M/1.13G [00:00<00:14, 79.5MB/s]
7%|▋ | 82.0M/1.13G [00:01<00:16, 70.5MB/s]
8%|▊ | 96.8M/1.13G [00:01<00:14, 75.5MB/s]
9%|▉ | 104M/1.13G [00:01<00:15, 73.3MB/s]
10%|▉ | 115M/1.13G [00:01<00:16, 64.6MB/s]
11%|█ | 129M/1.13G [00:01<00:13, 82.1MB/s]
12%|█▏ | 138M/1.13G [00:01<00:16, 66.9MB/s]
13%|█▎ | 153M/1.13G [00:02<00:12, 83.3MB/s]
14%|█▍ | 164M/1.13G [00:02<00:13, 78.4MB/s]
15%|█▌ | 179M/1.13G [00:02<00:12, 81.9MB/s]
16%|█▌ | 187M/1.13G [00:02<00:13, 74.1MB/s]
17%|█▋ | 196M/1.13G [00:02<00:15, 67.2MB/s]
17%|█▋ | 203M/1.13G [00:02<00:16, 60.8MB/s]
18%|█▊ | 213M/1.13G [00:03<00:17, 57.7MB/s]
19%|█▉ | 222M/1.13G [00:03<00:19, 50.4MB/s]
20%|█▉ | 229M/1.13G [00:03<00:19, 49.7MB/s]
20%|██ | 234M/1.13G [00:03<00:22, 43.2MB/s]
21%|██ | 244M/1.13G [00:03<00:19, 50.4MB/s]
21%|██▏ | 249M/1.13G [00:04<00:21, 44.1MB/s]
23%|██▎ | 262M/1.13G [00:04<00:16, 57.8MB/s]
24%|██▍ | 277M/1.13G [00:04<00:12, 71.8MB/s]
24%|██▍ | 284M/1.13G [00:04<00:14, 63.2MB/s]
25%|██▌ | 295M/1.13G [00:04<00:14, 63.0MB/s]
27%|██▋ | 311M/1.13G [00:04<00:12, 70.7MB/s]
28%|██▊ | 328M/1.13G [00:05<00:13, 64.1MB/s]
29%|██▉ | 339M/1.13G [00:05<00:14, 59.4MB/s]
30%|██▉ | 344M/1.13G [00:05<00:17, 49.7MB/s]
31%|███ | 359M/1.13G [00:05<00:12, 66.6MB/s]
32%|███▏ | 367M/1.13G [00:06<00:16, 51.4MB/s]
32%|███▏ | 377M/1.13G [00:06<00:16, 50.1MB/s]
34%|███▎ | 392M/1.13G [00:06<00:13, 61.7MB/s]
34%|███▍ | 399M/1.13G [00:06<00:14, 55.4MB/s]
35%|███▌ | 408M/1.13G [00:06<00:12, 63.3MB/s]
36%|███▌ | 415M/1.13G [00:06<00:13, 59.9MB/s]
37%|███▋ | 426M/1.13G [00:07<00:12, 63.3MB/s]
37%|███▋ | 432M/1.13G [00:07<00:15, 49.0MB/s]
38%|███▊ | 442M/1.13G [00:07<00:14, 50.8MB/s]
40%|███▉ | 459M/1.13G [00:07<00:11, 64.2MB/s]
41%|████ | 474M/1.13G [00:07<00:10, 67.4MB/s]
41%|████▏ | 480M/1.13G [00:07<00:11, 61.0MB/s]
42%|████▏ | 490M/1.13G [00:08<00:10, 66.7MB/s]
43%|████▎ | 497M/1.13G [00:08<00:12, 55.7MB/s]
44%|████▎ | 506M/1.13G [00:08<00:11, 58.8MB/s]
44%|████▍ | 512M/1.13G [00:08<00:12, 53.0MB/s]
45%|████▌ | 524M/1.13G [00:08<00:10, 61.3MB/s]
46%|████▋ | 539M/1.13G [00:08<00:08, 75.6MB/s]
47%|████▋ | 547M/1.13G [00:09<00:09, 67.4MB/s]
48%|████▊ | 556M/1.13G [00:09<00:10, 62.0MB/s]
48%|████▊ | 563M/1.13G [00:09<00:10, 57.9MB/s]
49%|████▉ | 574M/1.13G [00:09<00:09, 64.2MB/s]
51%|█████ | 590M/1.13G [00:09<00:07, 79.6MB/s]
52%|█████▏ | 606M/1.13G [00:09<00:07, 77.7MB/s]
53%|█████▎ | 621M/1.13G [00:10<00:06, 85.0MB/s]
54%|█████▍ | 630M/1.13G [00:10<00:07, 75.3MB/s]
55%|█████▌ | 639M/1.13G [00:10<00:08, 65.8MB/s]
56%|█████▋ | 655M/1.13G [00:10<00:07, 70.0MB/s]
57%|█████▋ | 662M/1.13G [00:10<00:08, 59.4MB/s]
58%|█████▊ | 672M/1.13G [00:10<00:07, 65.5MB/s]
59%|█████▉ | 687M/1.13G [00:11<00:06, 81.0MB/s]
60%|█████▉ | 695M/1.13G [00:11<00:07, 64.5MB/s]
60%|██████ | 702M/1.13G [00:11<00:07, 63.7MB/s]
61%|██████ | 709M/1.13G [00:11<00:08, 55.0MB/s]
62%|██████▏ | 720M/1.13G [00:11<00:07, 65.8MB/s]
63%|██████▎ | 727M/1.13G [00:11<00:07, 58.8MB/s]
63%|██████▎ | 737M/1.13G [00:12<00:07, 59.5MB/s]
64%|██████▍ | 743M/1.13G [00:12<00:08, 54.4MB/s]
65%|██████▍ | 754M/1.13G [00:12<00:07, 61.0MB/s]
66%|██████▋ | 769M/1.13G [00:12<00:05, 73.6MB/s]
67%|██████▋ | 776M/1.13G [00:12<00:06, 65.6MB/s]
68%|██████▊ | 786M/1.13G [00:12<00:05, 72.6MB/s]
68%|██████▊ | 793M/1.13G [00:13<00:07, 54.7MB/s]
69%|██████▉ | 802M/1.13G [00:13<00:06, 62.6MB/s]
70%|██████▉ | 809M/1.13G [00:13<00:06, 57.6MB/s]
71%|███████ | 819M/1.13G [00:13<00:05, 60.4MB/s]
71%|███████ | 826M/1.13G [00:13<00:05, 61.4MB/s]
72%|███████▏ | 836M/1.13G [00:13<00:05, 68.0MB/s]
73%|███████▎ | 852M/1.13G [00:13<00:04, 74.3MB/s]
74%|███████▍ | 859M/1.13G [00:14<00:04, 64.5MB/s]
75%|███████▍ | 867M/1.13G [00:14<00:04, 63.8MB/s]
75%|███████▌ | 873M/1.13G [00:14<00:05, 57.1MB/s]
76%|███████▌ | 884M/1.13G [00:14<00:04, 67.2MB/s]
77%|███████▋ | 891M/1.13G [00:14<00:05, 56.3MB/s]
77%|███████▋ | 900M/1.13G [00:14<00:04, 60.4MB/s]
78%|███████▊ | 906M/1.13G [00:14<00:04, 58.3MB/s]
79%|███████▉ | 916M/1.13G [00:15<00:03, 67.7MB/s]
79%|███████▉ | 923M/1.13G [00:15<00:05, 49.3MB/s]
80%|████████ | 934M/1.13G [00:15<00:04, 55.4MB/s]
82%|████████▏ | 949M/1.13G [00:15<00:03, 69.9MB/s]
82%|████████▏ | 956M/1.13G [00:15<00:03, 63.7MB/s]
83%|████████▎ | 967M/1.13G [00:15<00:03, 64.5MB/s]
85%|████████▍ | 982M/1.13G [00:16<00:02, 70.0MB/s]
85%|████████▌ | 989M/1.13G [00:16<00:03, 58.8MB/s]
86%|████████▌ | 995M/1.13G [00:16<00:03, 46.3MB/s]
86%|████████▌ | 0.98G/1.13G [00:16<00:04, 40.0MB/s]
87%|████████▋ | 0.99G/1.13G [00:16<00:02, 54.2MB/s]
89%|████████▊ | 1.01G/1.13G [00:17<00:01, 71.2MB/s]
89%|████████▉ | 1.01G/1.13G [00:17<00:01, 69.4MB/s]
90%|█████████ | 1.02G/1.13G [00:17<00:01, 72.8MB/s]
92%|█████████▏| 1.04G/1.13G [00:17<00:01, 84.1MB/s]
93%|█████████▎| 1.06G/1.13G [00:17<00:00, 89.2MB/s]
94%|█████████▍| 1.07G/1.13G [00:17<00:00, 97.8MB/s]
95%|█████████▌| 1.08G/1.13G [00:17<00:00, 84.7MB/s]
96%|█████████▌| 1.09G/1.13G [00:18<00:00, 82.2MB/s]
97%|█████████▋| 1.10G/1.13G [00:18<00:00, 87.0MB/s]
98%|█████████▊| 1.11G/1.13G [00:18<00:00, 72.2MB/s]
99%|█████████▊| 1.12G/1.13G [00:18<00:00, 60.1MB/s]
99%|█████████▉| 1.12G/1.13G [00:18<00:00, 54.4MB/s]
100%|██████████| 1.13G/1.13G [00:18<00:00, 64.6MB/s]
现在让我们定义我们典型的训练循环。在训练时您应该使用真实的图像,但为了本教程的目的,我们传入伪造的输入,而不必担心加载任何实际数据。
IMAGE_SIZE = 224
def train(model, optimizer):
# create our fake image input: tensor shape is batch_size, channels, height, width
fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
# call our forward and backward
loss = model.forward(fake_image)
loss.sum().backward()
# optimizer update
optimizer.step()
optimizer.zero_grad()
训练期间的内存使用情况#
我们即将查看一些内存快照,因此我们应该准备好正确分析它们。通常,训练内存包括:
模型参数(大小为 P)
为反向传播保存的激活(大小为 A)
梯度,其大小与模型参数相同,因此大小 G = P。
优化器状态,与参数大小成正比。在这种情况下,Adam 的状态需要 2 倍的模型参数,因此大小 O = 2P。
中间张量,在整个计算过程中分配。我们暂时不担心它们,因为它们通常很小且是短暂的。
捕获和可视化内存快照#
让我们获取一个内存快照!在您的代码运行时,请思考您期望的 CUDA 内存时间线会是什么样子。
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')
# train 3 steps
for _ in range(3):
train(model, optimizer)
# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot.pickle", "wb") as f:
dump(s, f)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
现在,通过拖放 snapshot.pickle
文件,在 CUDA 内存可视化工具 https://pytorch.ac.cn/memory_viz 中打开快照。内存时间线是否与您的预期相符?

在训练步骤开始之前,模型参数已经加载到内存中,所以我们一开始就看到一块内存被用于权重。当我们开始前向传播时,内存会逐渐为激活(即我们为在反向传播中计算梯度而保存的张量)分配。一旦我们开始反向传播,激活会逐渐被释放,而梯度的内存则开始增加。
最后,当优化器启动时,其状态将是惰性初始化的,所以我们应该只在第一个训练循环的优化器步骤中看到优化器状态内存逐渐增加。在未来的循环中,优化器内存将保持不变并进行原地更新。然后,在每个训练循环结束时调用 zero_grad
时,梯度内存会相应地被释放。
在这个训练循环中,内存瓶颈在哪里?换句话说,峰值内存在哪里?
峰值内存使用发生在优化器步骤期间!请注意,此时内存包含约 1.2GB 的参数、约 1.2GB 的梯度和约 2.4GB=2*1.2GB 的优化器状态,这符合预期。最后的约 1.2GB 来自 Adam 优化器需要用于中间变量的内存,总计峰值内存约为 6GB。从技术上讲,如果您设置 Adam(model.parameters(), foreach=False)
,您可以消除对最后 1.2GB 优化器中间变量的需求,这会以运行时换取内存。如果关闭 foreach
运行时优化对您来说已经足够节省内存,那很好,但如果您好奇本教程如何能帮助您做得更好,请继续阅读!通过我们即将介绍的技术,我们将通过消除对约 1.2GB 的**梯度内存**以及**优化器中间变量内存**的需求来减少峰值内存。现在,您期望新的峰值内存会是多少?答案将在*下一个*快照中揭晓。
免责声明:此技术并**不**适用于所有情况#
在我们过于兴奋之前,我们必须考虑这项技术是否适用于*您*的用例。这不是一个万能的解决方案!将优化器步骤融入反向传播的技术仅旨在减少*梯度*内存(并顺带减少优化器中间变量内存)。因此,梯度占用的内存越大,内存减少的效果就越显著。在我们上面的例子中,梯度占用了内存的 20%,这是相当可观的!
这可能不适用于您的情况,例如,如果您的权重已经很小(比如,因为应用了 LoRa),那么梯度在您的训练循环中就不会占用太多空间,带来的收益也就没那么令人兴奋了。在这种情况下,您应该首先尝试其他技术,如激活检查点、分布式训练、量化或减小批量大小。然后,当梯度再次成为瓶颈的一部分时,再回到本教程!
还在这里吗?很好,让我们介绍一下 Tensor 上的新 register_post_accumulate_grad_hook(hook)
API。
Tensor.register_post_accumulate_grad_hook(hook)
API 和我们的技术#
我们的技术依赖于在 backward()
期间不必保存梯度。相反,一旦一个梯度被累积,我们将立即对相应的参数应用优化器,并完全丢弃该梯度!这消除了在优化器步骤之前一直持有一个大梯度缓冲区的需要。
那么我们如何才能解锁这种更积极地应用优化器的行为呢?在我们的 2.1 版本中,我们添加了一个新的 API torch.Tensor.register_post_accumulate_grad_hook()
,它允许我们在一个张量的 .grad
字段被累积后,为其添加一个钩子。我们将把优化器步骤封装在这个钩子中。怎么做呢?
如何用 10 行代码将所有内容整合在一起#
还记得我们一开始的模型和优化器设置吗?我将它们注释掉,这样我们就不会花费资源重新运行代码。
model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}
# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
optimizer_dict[parameter].step()
optimizer_dict[parameter].zero_grad()
# Register the hook onto every parameter
for p in model.parameters():
p.register_post_accumulate_grad_hook(optimizer_hook)
# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
# create our fake image input: tensor shape is batch_size, channels, height, width
fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
# call our forward and backward
loss = model.forward(fake_image)
loss.sum().backward()
# optimizer update --> no longer needed!
# optimizer.step()
# optimizer.zero_grad()
在我们的示例模型中,这大约需要 10 行的更改,这很简洁。然而,对于真实模型来说,将优化器换成优化器字典可能是一个相当大的侵入性改动,特别是对于那些使用 `LRScheduler` 或在整个训练周期中操作优化器配置的用户。将此 API 与这些更改结合使用会更复杂,并且可能需要将更多配置移至全局状态,但这并非不可能。也就是说,PyTorch 的下一步是使这个 API 更容易与您已经习惯的 LRScheduler 和其他功能一起使用。
但让我回到说服您这项技术是值得的。我们将咨询我们的朋友,内存快照。
# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')
# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
train(model)
# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot-opt-in-bwd.pickle", "wb") as f:
dump(s, f)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
是的,花点时间将您的快照拖到 CUDA 内存可视化工具中。

- 几个主要观察结果:
不再有优化器步骤!是的……我们已经把它融入了反向传播。
同样,反向传播的时间更长,并且有更多用于中间变量的随机分配。这是预料之中的,因为优化器步骤需要中间变量。
最重要的是!峰值内存降低了!现在大约是 4GB(我希望这与您之前的预期非常接近)。
请注意,与之前相比,不再有大块内存分配给梯度,这节省了约 1.2GB 的内存。相反,我们通过尽可能提前移动优化器步骤,在每个梯度计算后非常快地释放了它们。太棒了!顺便说一句,另外约 1.2GB 的内存节省来自于将优化器分解为每个参数的优化器,因此中间变量也相应地缩小了。这个细节比梯度内存的节省*不那么重要*,因为您只需关闭 foreach=False
就可以获得优化器中间变量的节省,而无需使用此技术。
您可能会正确地想:如果我们节省了 2.4GB 的内存,为什么峰值内存不是 6GB - 2.4GB = 3.6GB?嗯,峰值移动了!峰值现在位于反向传播步骤的开始附近,此时我们内存中仍有激活,而之前,峰值是在优化器步骤期间,此时激活已被释放。因此,~4.0GB - ~3.6GB 的 ~0.4GB 差异是由于激活内存造成的。可以想象,这项技术可以与激活检查点结合使用,以获得更多的内存收益。
结论#
在本教程中,我们学习了通过新的 Tensor.register_post_accumulate_grad_hook()
API 将优化器融入反向传播步骤的内存节省技术,以及*何时*应用此技术(当梯度内存显著时)。在此过程中,我们还了解了内存快照,这对于内存优化通常很有用。
脚本总运行时间: (0 分 25.141 秒)