注意
转到末尾 下载完整的示例代码。
Autograd 保存张量的钩子#
创建时间:2021 年 11 月 03 日 | 上次更新时间:2024 年 8 月 27 日 | 上次验证时间:未验证
PyTorch 通常使用反向传播来计算梯度。但是,某些操作需要保存中间结果才能执行反向传播。本教程将介绍如何保存/检索这些张量,以及如何定义钩子来控制打包/解包过程。
本教程假定您熟悉反向传播在理论上的工作方式。如果不是,请先阅读此内容。
保存的张量#
训练模型通常比运行模型进行推理消耗更多的内存。 广义地说,可以说这是因为“PyTorch需要保存计算图,该图需要调用 backward
”,因此需要额外的内存使用。本教程的一个目标是微调这种理解。
事实上,图本身有时不会消耗更多的内存,因为它从不复制任何张量。但是,该图可以保留对原本会超出范围的张量的引用:这些张量被称为保存的张量。
为什么训练模型(通常)比评估模型需要更多的内存?#
我们从一个简单的例子开始: \(y = a \cdot b\),我们知道 \(y\) 相对于 \(a\) 和 \(b\) 的梯度
import torch
a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b
使用 torchviz,我们可以可视化计算图
在此示例中,PyTorch 保存中间值 \(a\) 和 \(b\),以便在向后传播期间计算梯度。
可以通过查找以 _saved
为前缀的 y
的 grad_fn
的属性来访问(出于调试目的)这些中间值(上图中为橙色)。
tensor([-0.2844, -0.2342, -1.0392, 0.9483, 1.2858], requires_grad=True)
tensor([1., 1., 1., 1., 1.], requires_grad=True)
随着计算图深度的增加,它将存储更多的保存的张量。与此同时,如果不是因为该图,这些张量就会超出范围。

在上面的示例中,在没有梯度的情况下执行只会将 x
和 y
保留在作用域中。但该图还会额外存储 f(x)
和 f(f(x))
。因此,在训练期间运行前向传播在内存使用方面比评估期间(更准确地说,当不需要 autograd 时)更昂贵。
打包/解包的概念#
回到第一个例子: y.grad_fn._saved_self
和 y.grad_fn._saved_other
分别指向原始张量对象 a
和 b
。
True
True
但是,情况可能并非总是如此。
True
False
在底层,PyTorch 已经打包和解包了张量 y
,以防止引用循环。
作为经验法则,您不应该依赖于访问为反向传播保存的张量会产生与原始张量相同的张量对象。但是,它们将共享相同的存储。
保存的张量钩子#
PyTorch 提供了一个 API 来控制应如何打包/解包保存的张量。
def pack_hook(x):
print("Packing", x)
return x
def unpack_hook(x):
print("Unpacking", x)
return x
a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
y = a * b
y.sum().backward()
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
每次操作保存张量以进行反向传播时,都会调用 pack_hook
函数。然后,pack_hook
的输出存储在计算图中,而不是原始张量中。 unpack_hook
使用该返回值来计算新张量,该张量是在反向传播过程中实际使用的张量。通常,您希望 unpack_hook(pack_hook(t))
等于 t
。
需要注意的一件事是,pack_hook
的输出可以是任何 Python 对象,只要 unpack_hook
可以从中派生出具有正确值的张量即可。
一些非常规的例子#
首先,一些愚蠢的例子来说明可能发生的事情,但您可能永远不想这样做。
返回一个 int
#
返回 Python 列表的索引 相对无害,但实用性值得商榷
返回一个元组#
返回一些张量和一个解包函数 在当前形式下不太可能有用
def pack(x):
delta = torch.randn(*x.size())
return x - delta, lambda x: x + delta
def unpack(packed):
x, f = packed
return f(x)
x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
y = x * x
y.sum().backward()
assert(torch.allclose(x.grad, 2 * x))
返回一个 str
#
返回张量的 __repr__
可能永远不会这样做
虽然这些例子在实践中没有用处,但它们说明 pack_hook
的输出实际上可以是任何 Python 对象,只要它包含足够的信息来检索原始张量的内容即可。在接下来的部分中,我们将重点介绍更有用的应用程序。
将张量保存到 CPU#
通常,计算图中涉及的张量位于 GPU 上。在图中保留对这些张量的引用是导致大多数模型在训练期间耗尽 GPU 内存的原因,而它们在评估期间运行良好。
钩子提供了一种非常简单的方法来实现这一点。
def pack_hook(x):
return (x.device, x.cpu())
def unpack_hook(packed):
device, tensor = packed
return tensor.to(device)
x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
y = x * x
y.sum().backward()
torch.allclose(x.grad, (2 * x))
True
事实上,PyTorch 提供了一个 API 来方便地使用这些钩子(以及使用固定内存的能力)。
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.randn(5))
def forward(self, x):
with torch.autograd.graph.save_on_cpu(pin_memory=True):
# some computation
return self.w * x
x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()
实际上,在 A100 GPU 上,对于批量大小为 256 的 ResNet-152,这对应于 GPU 内存使用量从 48GB 减少到 5GB,但代价是速度降低了 6 倍。
当然,您可以通过仅将网络的某些部分保存到 CPU 来调节这种权衡。
例如,您可以定义一个特殊的 nn.Module
,它可以包装任何模块并将其张量保存到 CPU。
class SaveToCpu(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
with torch.autograd.graph.save_on_cpu(pin_memory=True):
return self.module(*args, **kwargs)
model = nn.Sequential(
nn.Linear(10, 100),
SaveToCpu(nn.Linear(100, 100)),
nn.Linear(100, 10),
)
x = torch.randn(10)
loss = model(x).sum()
loss.backward()
将张量保存到磁盘#
同样,您可能希望将这些张量保存到磁盘。同样,这可以通过这些钩子来实现。
一个幼稚的版本如下所示。
# Naive version - HINT: Don't do this
import uuid
tmp_dir = "temp"
def pack_hook(tensor):
name = os.path.join(tmp_dir, str(uuid.uuid4()))
torch.save(tensor, name)
return name
def unpack_hook(name):
return torch.load(name, weights_only=True)
上面的代码不好的原因是我们在磁盘上泄漏文件,并且它们永远不会被清除。修复此问题并不像看起来那么简单。
# Incorrect version - HINT: Don't do this
import uuid
import os
import tempfile
tmp_dir_obj = tempfile.TemporaryDirectory()
tmp_dir = tmp_dir_obj.name
def pack_hook(tensor):
name = os.path.join(tmp_dir, str(uuid.uuid4()))
torch.save(tensor, name)
return name
def unpack_hook(name):
tensor = torch.load(name, weights_only=True)
os.remove(name)
return tensor
上面的代码不起作用的原因是 unpack_hook
可以被多次调用。如果我们第一次解包时删除文件,那么当第二次访问保存的张量时,该文件将不可用,这将引发错误。
x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
y = x.pow(2)
print(y.grad_fn._saved_self)
try:
print(y.grad_fn._saved_self)
print("Double access succeeded!")
except:
print("Double access failed!")
tensor([1., 1., 1., 1., 1.], requires_grad=True)
Double access failed!
为了解决这个问题,我们可以编写一个利用 PyTorch 在不再需要时自动释放(删除)保存的数据这一事实的钩子版本。
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name, weights_only=True)
当我们调用 backward
时,pack_hook
的输出将被删除,这将导致文件被删除,因此我们不再泄漏文件。
然后,可以以以下方式在您的模型中使用它
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class SaveToDisk(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
return self.module(*args, **kwargs)
net = nn.DataParallel(SaveToDisk(Model()))
在最后一个示例中,我们还演示了如何过滤应保存哪些张量(此处为元素数量大于 1000 的张量),以及如何将此功能与 nn.DataParallel
结合使用。
如果您已经走到这一步,恭喜!您现在知道如何使用保存的张量钩子,以及它们如何在一些场景中有用,以权衡内存以获得计算能力。
脚本的总运行时间: (0 分钟 0.270 秒)