Autograd 机制#
创建日期: 2017 年 1 月 16 日 | 最后更新: 2025 年 6 月 16 日
本笔记将概述 autograd 的工作原理及其记录操作的方式。了解所有这些并非绝对必要,但我们建议您熟悉它,因为这将帮助您编写更高效、更简洁的程序,并有助于您调试。
Autograd 如何编码历史记录#
Autograd 是一个反向自动微分系统。从概念上讲,autograd 会记录一个图,记录下创建数据的所有操作,为您提供一个有向无环图,其中叶子是输入张量,根是输出张量。通过从根到叶跟踪此图,您可以自动使用链式法则计算梯度。
在内部,autograd 将此图表示为 Function 对象的图(实际上是表达式),这些对象可以被 apply() 以计算图的评估结果。在计算前向传播时,autograd 同时执行请求的计算,并构建一个表示计算梯度的函数图(每个 torch.Tensor 的 .grad_fn 属性是此图的入口点)。当前向传播完成后,我们在反向传播中评估此图以计算梯度。
需要注意的一点是,该图在每次迭代时都会从头开始重建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时更改图的整体形状和大小。您无需在启动训练前编码所有可能的路径——您运行的就是您所微分的。
保存的张量#
某些操作需要在前向传播期间保存中间结果才能执行反向传播。例如,函数 会保存输入 以计算梯度。
在定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传播期间保存张量,并使用 saved_tensors 在反向传播期间检索它们。有关更多信息,请参阅 扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()),张量会根据需要自动保存。您可以(出于教育或调试目的)通过查找以 _saved 为前缀的属性,来探索特定 grad_fn 保存了哪些张量。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self 指代与 x 相同的 Tensor 对象。但这并不总是成立的。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底层,为了防止引用循环,PyTorch 在保存时对张量进行了*打包*,并在读取时将其*解包*到另一个张量中。在这里,您从访问 y.grad_fn._saved_result 中获得的张量是与 y 不同的张量对象(但它们仍然共享相同的存储区)。
一个张量是否会被打包成另一个张量对象,取决于它是否是其自身的 grad_fn 的输出,这是一个可能更改的实现细节,用户不应依赖它。
您可以通过 保存张量的钩子 控制 PyTorch 执行打包/解包的方式。
不可微函数的梯度#
使用自动微分的梯度计算仅在所使用的每个基本函数可微时才有效。不幸的是,我们在实践中使用的许多函数不具备此属性(例如 relu 或 sqrt 在 0 处)。为了尽量减少不可微函数的影响,我们按以下顺序定义基本操作的梯度:
如果函数是可微的,因此在当前点存在梯度,则使用该梯度。
如果函数是凸的(至少在局部),则使用最小范数的次梯度。
如果函数是凹的(至少在局部),则使用最小范数的上梯度(考虑 -f(x) 并应用前一点)。
如果函数已定义,则通过连续性在当前点定义梯度(注意这里可能出现
inf,例如对于sqrt(0))。如果可能存在多个值,则任意选择一个。如果函数未定义(例如
sqrt(-1)、log(-1)或在输入为NaN时大多数函数),则用作梯度的值是任意的(我们也可以引发错误,但这不保证)。大多数函数将使用NaN作为梯度,但出于性能原因,一些函数将使用其他值(例如log(-1))。如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微。如果在
no_grad环境之外用于需要梯度的张量上,这将在反向传播时导致错误。
Autograd 中的除以零#
当在 PyTorch 中执行除以零操作时(例如 x / 0),前向传播将按照 IEEE-754 浮点算术产生 inf 值。虽然这些 inf 值可以在计算最终损失之前通过掩码移除(例如通过索引或掩码),但 autograd 系统仍然会跟踪并通过完整的计算图进行微分,包括除以零操作。
在反向传播期间,这可能导致有问题的梯度表达式。例如
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div # Results in [inf, 1]
mask = div != 0 # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad) # [nan, 1], not [0, 1]
在此示例中,即使我们只使用被掩码的输出(其中排除了除以零的操作),autograd 仍然通过完整的计算图计算梯度,包括除以零的操作。这会导致被掩码元素产生 nan 梯度,从而可能导致训练不稳定。
为避免此问题,有几种推荐的方法:
在除法前进行掩码
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward() # Produces safe gradients [0, 1]
使用 MaskedTensor(实验性 API)
from torch.masked import as_masked_tensor
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward() # Cleanly handles "undefined" vs "zero" gradients
关键原则是防止将除以零操作记录在计算图中,而不是事后掩盖其结果。这确保了 autograd 只通过有效操作计算梯度。
在使用可能产生 inf 或 nan 值的操作时,记住此行为非常重要,因为掩盖输出并不能阻止有问题的梯度被计算出来。
局部禁用梯度计算#
Python 提供了几种机制可以在本地禁用梯度计算:
要跨整个代码块禁用梯度,有诸如 no-grad 模式和推理模式之类的上下文管理器。要从梯度计算中更精细地排除子图,可以通过设置张量的 requires_grad 字段来实现。
下面,除了讨论上述机制外,我们还描述了评估模式(nn.Module.eval()),这是一种不用于禁用梯度计算的方法,但由于其名称,经常与前三种方法混淆。
设置 requires_grad#
requires_grad 是一个标志,默认为 false(*除非包装在 nn.Parameter *中),它允许对子图进行精细控制,将其排除在梯度计算之外。它在前向和反向传播中都有效。
在前向传播期间,只有当至少一个输入张量需要梯度时,操作才会被记录在反向图中。在反向传播(.backward())期间,只有设置了 requires_grad=True 的叶子张量才会将梯度累积到它们的 .grad 字段中。
需要注意的是,尽管每个张量都有此标志,但*设置*它仅对叶子张量有意义(即没有 grad_fn 的张量,例如 nn.Module 的参数)。非叶子张量(即具有 grad_fn 的张量)是与反向图相关联的张量。因此,需要它们的梯度作为中间结果来计算需要梯度的叶子张量的梯度。根据此定义,很明显所有非叶子张量的 require_grad 都会自动设置为 True。
设置 requires_grad 应是控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的部分。
要冻结模型的部分,只需对不希望更新的参数应用 .requires_grad_(False)。如上所述,由于使用这些参数作为输入的计算不会在前向传播中被记录,因此在反向传播中它们的 .grad 字段不会被更新,因为它们一开始就不在反向图中,正如预期的那样。
由于这是一个非常常见的模式,requires_grad 也可以通过模块级别的 nn.Module.requires_grad_() 来设置。当应用于模块时,.requires_grad_() 会影响模块的所有参数(这些参数默认设置为 requires_grad=True)。
梯度模式#
除了设置 requires_grad 之外,还有三种可以从 Python 中选择的梯度模式,它们可能会影响 PyTorch 中计算的内部处理方式:默认模式(梯度模式)、无梯度模式和推理模式,所有这些都可以通过上下文管理器和装饰器进行切换。
模式 |
将操作排除在反向图记录之外 |
跳过额外的 autograd 跟踪开销 |
在模式启用期间创建的张量之后可用于梯度模式的计算 |
示例 |
|---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
无梯度 |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理、模型评估 |
默认模式(梯度模式)#
“默认模式”是我们隐含处于的模式,当没有启用其他模式(如 no-grad 和 inference 模式)时。与“no-grad 模式”相对,默认模式有时也称为“grad 模式”。
关于默认模式最重要的一点是,它是唯一一个 requires_grad 生效的模式。在另外两种模式中,requires_grad 始终被覆盖为 False。
无梯度模式#
no-grad 模式下的计算表现为所有输入都不需要梯度。换句话说,no-grad 模式下的计算永远不会被记录在反向图中,即使存在设置了 require_grad=True 的输入。
当您需要执行不应被 autograd 记录的操作,但仍希望稍后在梯度模式中使用这些计算的输出来进行计算时,启用 no-grad 模式很有用。此上下文管理器可以方便地禁用代码块或函数的梯度,而无需临时将张量设置为 requires_grad=False,然后再改回 True。
例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望就地更新参数,而不希望这次更新被 autograd 记录。您还打算在下一次前向传播中使用更新后的参数进行梯度模式下的计算。
torch.nn.init 中的实现也依赖于 no-grad 模式来初始化参数,以避免在就地更新初始化参数时被 autograd 跟踪。
推理模式#
推理模式是 no-grad 模式的极端版本。与 no-grad 模式一样,推理模式下的计算不会记录在反向图中,但启用推理模式将使 PyTorch 能够进一步加速您的模型。这种更好的运行时带来了一个缺点:在推理模式下创建的张量在退出推理模式后将无法用于要被 autograd 记录的计算中。
当您执行与 autograd 无关的计算,并且不打算在之后将推理模式下创建的张量用于任何要被 autograd 记录的计算时,启用推理模式。
建议您在不需要 autograd 跟踪的代码部分(例如数据处理和模型评估)中尝试推理模式。如果您的用例开箱即用,那就是免费的性能提升。如果您在启用推理模式后遇到错误,请检查您是否在退出推理模式后在被 autograd 记录的计算中使用了在推理模式下创建的张量。如果您在用例中无法避免这种使用,您可以随时切换回 no-grad 模式。
有关推理模式的详细信息,请参阅 推理模式。
有关推理模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式(nn.Module.eval())#
评估模式不是一种在本地禁用梯度计算的机制。将它包含在这里是因为它有时被误认为是这样的机制。
从功能上讲,module.eval()(或等效的 module.train(False))与 no-grad 模式和推理模式完全正交。 model.eval() 如何影响您的模型,完全取决于您模型中使用的特定模块以及它们是否定义了任何训练模式特有的行为。
您有责任调用 model.eval() 和 model.train(),如果您的模型依赖于诸如 torch.nn.Dropout 和 torch.nn.BatchNorm2d 这样的模块,它们可能根据训练模式表现不同,例如,以避免在验证数据上更新 BatchNorm 的运行统计数据。
建议您在训练时始终使用 model.train(),在评估模型时(验证/测试)使用 model.eval(),即使您不确定模型是否具有训练模式特有的行为,因为您正在使用的模块可能会更新为在训练和评估模式下表现不同。
Autograd 中的就地操作#
在 autograd 中支持就地操作是一件困难的事情,我们大多数情况下不鼓励使用它们。Autograd 的积极缓冲区释放和重用使其非常高效,很少有情况下就地操作能显著降低内存使用量。除非您面临巨大的内存压力,否则您可能永远不需要使用它们。
限制就地操作适用性的主要有两个原因:
就地操作可能会覆盖计算梯度所需的*值*。
每个就地操作都要求实现重写计算图。非就地版本仅分配新对象并保留对旧图的引用,而就地操作要求更改表示此操作的
Function的所有输入的创建者。这可能很棘手,特别是当许多张量引用同一个存储区时(例如,通过索引或转置创建),并且如果任何其他Tensor引用了被修改输入的存储区,就地函数将引发错误。
就地正确性检查#
每个张量都保存一个版本计数器,每次在任何操作中被标记为“脏”时,该计数器都会递增。当一个 Function 为反向传播保存任何张量时,也会保存其包含的张量的版本计数器。一旦访问 self.saved_tensors,就会对其进行检查,如果它大于保存的值,则会引发错误。这确保了如果您使用的是就地函数且没有看到任何错误,您可以确信计算出的梯度是正确的。
多线程 Autograd#
autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有细节,以帮助您在多线程环境(仅适用于 PyTorch 1.6+,因为早期版本的行为不同)中充分利用它。
用户可以使用多线程代码(例如 Hogwild 训练)来训练模型,并且不会在并发的反向计算上阻塞,示例如下:
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意用户应该了解的一些行为:
CPU 上的并发#
当您通过 Python 或 C++ API 在多线程上运行 backward() 或 grad() 时,您期望看到额外的并发,而不是在执行期间按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。
非确定性#
如果您从多个线程并发调用 backward() 并共享输入(即 Hogwild CPU 训练),则应预期出现非确定性。这可能是因为参数会自动在线程间共享,因此多个线程可能会在梯度累积期间访问并尝试累积同一个 .grad 属性。这在技术上不安全,可能导致竞态条件,结果可能无效。
开发具有共享参数的多线程模型的用户应牢记线程模型,并应了解上述问题。
可以使用函数式 API torch.autograd.grad() 代替 backward() 来计算梯度,以避免非确定性。
图保留#
如果 autograd 图的一部分在线程间共享(即先在单线程中运行前向传播的第一部分,然后在多个线程中运行第二部分),则图的第一部分是共享的。在这种情况下,在同一图上执行 grad() 或 backward() 的不同线程可能会在其中一个线程上即时销毁图,而另一个线程将崩溃。Autograd 会向用户报错,类似于调用两次 backward() 而不设置 retain_graph=True 的情况,并告知用户他们应该使用 retain_graph=True。
Autograd 节点上的线程安全#
由于 Autograd 允许调用线程驱动其反向执行以实现潜在的并行性,因此确保 CPU 上多线程 backward() 调用的线程安全(这些调用共享部分/全部 GraphTask)非常重要。
自定义 Python autograd.Functions 由于 GIL 而自动具有线程安全性。对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Functions,Autograd 引擎使用线程互斥锁来确保可能具有状态读/写操作的 autograd 节点的线程安全。
C++ 钩子没有线程安全#
Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望在多线程环境中正确应用钩子,则需要编写适当的线程锁定代码以确保钩子是线程安全的。
复数数的 Autograd#
简而言之
当您使用 PyTorch 对任何具有复数域和/或上域的函数 进行微分时,梯度是在函数是更大的实值损失函数 的一部分的假设下计算的。计算出的梯度是 (注意 z 的共轭),其负值正是梯度下降算法中使用的最陡下降方向。因此,存在一个可行的方法,使得现有的优化器可以与复数参数开箱即用地工作。
此约定与 TensorFlow 的复数微分约定一致,但与 JAX 不同(JAX 计算 ).
如果你有一个内部使用复数运算的实值函数,这里的约定就不重要了:你将始终得到与仅使用实数运算实现时相同的结果。
如果你对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?#
复数可微的数学定义是将导数的极限定义推广到适用于复数。考虑一个函数 ,
其中 和 是两个变量实值函数,而 是虚数单位。
利用导数的定义,我们可以写出
为了使这个极限存在,不仅 和 必须是实可微的,而且 还必须满足柯西-黎曼 方程。换句话说:针对实部和虚部步长()计算的极限必须相等。这是一个更严格的条件。
复数可微函数通常被称为全纯函数。它们表现良好,具有你在实可微函数中见过的所有优良性质,但在优化领域几乎没有用处。对于优化问题,研究界通常只使用实值目标函数,因为复数不属于任何有序域,因此具有复数值的损失函数意义不大。
事实证明,没有一个有趣实值目标函数能满足柯西-黎曼方程。因此,全纯函数的理论不能用于优化,所以大多数人使用维尔廷格微积分。
维尔廷格微积分登场了……#
因此,我们有了关于复数可微性和全纯函数的这一重要理论,但我们完全无法使用它,因为许多常用的函数都不是全纯的。一个可怜的数学家该怎么办呢?维尔廷格观察到,即使 不是全纯的,也可以将其重写为双变量函数 ,该函数总是全纯的。这是因为 的实部和虚部可以用 和 来表示:
维尔廷格微积分建议研究 而不是,它保证是全纯的,如果 是实可微的(另一种思考方式是坐标变换,从 到 )。这个函数具有关于 和 偏导数,我们可以利用链式法则建立这些偏导数与 的实部和虚部偏导数之间的关系。
从上面的方程,我们得到
这就是您在维基百科上会找到的关于Wirtinger微积分的经典定义。
这种变化带来了许多美妙的推论。
例如,柯西-黎曼方程可以简化为,也就是说,函数可以完全用来表示,而无需引用。
另一个重要(且有些违反直觉)的结果是,正如我们稍后将看到的,当我们在实值损失上进行优化时,在进行变量更新时应采取的步骤由(而不是)给出。
更多阅读,请查阅:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger微积分在优化中有何用处?#
音频和其他领域的研究人员更常使用梯度下降来优化具有复变量的实值损失函数。通常,这些人会将实部和虚部视为可以更新的独立通道。对于步长和损失,我们可以将以下方程写在中:
这些方程在复空间 中如何翻译?
一个非常有趣的事情发生了:Wirtinger微积分告诉我们,可以将上述复变量更新公式简化为仅引用共轭Wirtinger导数 ,这正是我们在优化中采取的步骤。
由于共轭Wirtinger导数给出了实值损失函数的正确步长,PyTorch在对具有实值损失的函数进行微分时,会提供此导数。
PyTorch 如何计算共轭Wirtinger导数?#
通常,我们的导数公式将 grad_output 作为输入,表示我们已经计算出的传入的 Vector-Jacobian 积,即 ,其中 是整个计算(产生实值损失)的损失,而 是我们函数的输出。目标是计算 ,其中 是函数的输入。事实证明,在实值损失的情况下,我们可以仅计算 ,尽管链式法则暗示我们还需要访问 . 如果您想跳过这个推导,请查看本节的最后一个方程,然后跳到下一节。
让我们继续处理 ,定义为 . 如上所述,自动微分的梯度约定以实值损失函数的优化为中心,因此我们假设 是一个更大的实值损失函数 的一部分。使用链式法则,我们可以写出
(1)#
现在使用Wirtinger导数定义,我们可以写出:
需要注意的是,由于 和 是实函数,并且 是实数(因为我们假设 是一个实值函数的一部分),我们有:
(2)#
即, 等于 。
求解上述关于 和 的方程,我们得到:
(3)#
Using (2), we get
(4)#
这个最后的方程对于编写你自己的梯度很重要,因为它将我们的导数公式分解成一个更简单的、可以用手轻松计算的公式。
我如何为复函数编写自己的导数公式?#
上面方框中的方程为所有复函数上的导数提供了通用公式。然而,我们仍然需要计算 和 。有两种方法可以做到这一点。
第一种方法是直接使用 Wirtinger 导数的定义,并计算 和 ,通过使用 和 (这可以用常规方法计算).
第二种方法是使用变量替换技巧,并将 重写为双变量函数 ,并通过将 和 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所考虑的函数是全纯的,则只会使用 (而 将为零)。
让我们以 作为示例,其中 。
使用第一种方法计算 Wirtinger 导数,我们得到。
使用 (4),以及 grad_output = 1.0(这是 PyTorch 中调用标量输出的 backward() 时使用的默认 grad 输出值),我们得到:
Using the second way to compute Wirtinger derivatives, we directly get
And using (4) again, we get . As you can see, the second way involves lesser calculations, and comes in more handy for faster calculations.
What about cross-domain functions?#
Some functions map from complex inputs to real outputs, or vice versa. These functions form a special case of (4), which we can derive using the chain rule
For , we get
For , we get
Hooks for saved tensors#
您可以通过定义一对 pack_hook / unpack_hook 钩子来控制 如何打包/解包保存的张量。 pack_hook 函数应将其单个参数作为张量,但可以返回任何 Python 对象(例如另一个张量、元组,甚至是包含文件名的字符串)。 unpack_hook 函数将其单个参数作为 pack_hook 的输出,并应返回一个要在反向传播中使用的张量。 unpack_hook 返回的张量只需要与作为 pack_hook 输入的张量具有相同的内容。特别是,可以忽略任何与 autograd 相关的元数据,因为它们将在解包过程中被覆盖。
这样一对钩子示例如下
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook 不应删除临时文件,因为它可能会被多次调用:只要返回的 SelfDeletingTempFile 对象存在,临时文件就应该保持存在。在上述示例中,我们通过在不再需要临时文件时关闭它(在 SelfDeletingTempFile 对象删除时)来防止临时文件泄漏。
注意
我们保证 pack_hook 只会被调用一次,但 unpack_hook 可以根据反向传播的需要被调用任意次数,并且我们期望它每次都返回相同的数据。
警告
禁止对任一函数的输入执行原地操作,因为这可能会导致意外的副作用。如果修改了 pack 钩子的输入,PyTorch 将抛出错误,但不会捕获 unpack 钩子输入被原地修改的情况。
为已保存的张量注册钩子#
您可以通过在 SavedTensor 对象上调用 register_hooks() 方法来为已保存的张量注册一对钩子。这些对象作为 grad_fn 的属性暴露,并以 _raw_saved_ 前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
只要注册了这对钩子,就会调用 pack_hook 方法。每次需要访问已保存的张量时(无论是通过 y.grad_fn._saved_self 还是在反向传播期间),都会调用 unpack_hook 方法。
警告
如果您在已保存的张量被释放后(即调用 backward 之后)仍然维护对 SavedTensor 的引用,则禁止调用其 register_hooks()。PyTorch 大多数情况下会抛出错误,但在某些情况下可能无法做到,并可能出现未定义的行为。
为已保存的张量注册默认钩子#
或者,您可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子,这些钩子将应用于该上下文创建的所有已保存张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x.detach()
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
用此上下文管理器定义的钩子是线程本地的。因此,以下代码不会产生预期的效果,因为钩子不会经过 DataParallel。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有原地优化以减少张量对象的创建。例如
with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
没有钩子时,x、y.grad_fn._saved_self 和 y.grad_fn._saved_other 都指向同一个张量对象。使用钩子时,PyTorch 会将 x 打包和解包成两个新的张量对象,它们与原始的 x 共享相同的存储(不执行复制)。
反向传播钩子的执行#
本节将讨论不同钩子何时触发或不触发。然后将讨论它们触发的顺序。将涵盖的钩子有:通过 torch.Tensor.register_hook() 注册到张量的反向钩子,通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的后累积梯度钩子,通过 torch.autograd.graph.Node.register_hook() 注册到节点的后钩子,以及通过 torch.autograd.graph.Node.register_prehook() 注册到节点的预钩子。
特定钩子是否会被触发#
通过 torch.Tensor.register_hook() 注册到张量的钩子在计算该张量的梯度时执行。(注意,这不需要执行张量的 grad_fn。例如,如果张量作为 torch.autograd.grad() 的 inputs 参数的一部分传递,则张量的 grad_fn 可能不会执行,但注册到该张量的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的钩子在其梯度被累积后执行,这意味着张量的 grad 字段已被设置。而通过 torch.Tensor.register_hook() 注册的钩子在梯度计算时运行,而通过 torch.Tensor.register_post_accumulate_grad_hook() 注册的钩子仅在反向传播结束时张量的 grad 字段被 autograd 更新后才触发。因此,后累积梯度钩子只能为叶张量注册。在非叶张量上通过 torch.Tensor.register_post_accumulate_grad_hook() 注册钩子将会报错,即使您调用 backward(retain_graph=True)。
使用 torch.autograd.graph.Node.register_hook() 或 torch.autograd.graph.Node.register_prehook() 注册到 torch.autograd.graph.Node 的钩子仅在该 Node 被执行时才触发。
特定 Node 是否被执行可能取决于反向传播是使用 torch.autograd.grad() 还是 torch.autograd.backward() 调用的。具体来说,当您在对应于作为 inputs 参数传递给 torch.autograd.grad() 或 torch.autograd.backward() 的张量的 Node 上注册钩子时,您应该注意这些差异。
如果您使用 torch.autograd.backward(),上述所有提到的钩子都将被执行,无论您是否指定了 inputs 参数。这是因为 .backward() 会执行所有 Node,即使它们对应于作为输入指定的张量。(请注意,对应于作为 inputs 传递的张量的这个附加 Node 的执行通常是不必要的,但仍然会执行。此行为可能会发生变化;您不应依赖它。)
另一方面,如果您使用 torch.autograd.grad(),注册到对应于传递给 input 的张量的 Nodes 的反向钩子可能不会被执行,因为除非有另一个输入依赖于该 Node 的梯度结果,否则那些 Nodes 不会被执行。
不同钩子触发的顺序#
事件发生的顺序如下:
注册到张量的钩子被执行
注册到 Node 的预钩子被执行(如果 Node 被执行)。
保留
.grad的张量的.grad字段被更新Node 被执行(受上述规则限制)
对于梯度被累积的叶张量,后累积梯度钩子被执行
注册到 Node 的后钩子被执行(如果 Node 被执行)
如果在同一个张量或 Node 上注册了多个相同类型的钩子,它们将按照注册的顺序执行。后执行的钩子可以观察到早期钩子对梯度所做的修改。
特殊钩子#
torch.autograd.graph.register_multi_grad_hook() 是使用注册到张量的钩子实现的。每个单独的张量钩子都按照上面定义的张量钩子顺序触发,当计算完最后一个张量梯度时,会调用注册的多梯度钩子。
torch.nn.modules.module.register_module_full_backward_hook() 是使用注册到 Node 的钩子实现的。在计算前向传播时,钩子会注册到模块输入和输出对应的 grad_fn 上。因为一个模块可能有多个输入和多个输出,所以在前向传播之前应用于模块输入和在返回模块前向传播输出之前应用于模块输出,会首先应用一个虚拟的自定义 autograd Function,以确保这些张量共享一个 grad_fn,我们可以在其上附加我们的钩子。
当张量被原地修改时,张量钩子的行为#
通常,注册到张量的钩子接收输出相对于该张量的梯度,其中张量的值被认为是它在计算反向传播时的值。
然而,如果您向张量注册钩子,然后原地修改该张量,在原地修改之前注册的钩子同样会接收输出相对于该张量的梯度,但张量的值被认为是原地修改之前的值。
如果您更喜欢前一种情况的行为,则应在对张量进行所有原地修改之后再为其注册钩子。例如
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解底层机制可能很有帮助:当钩子注册到张量时,它们实际上会永久绑定到该张量的 grad_fn,因此如果该张量随后被原地修改,即使张量现在有了新的 grad_fn,在原地修改之前注册的钩子仍将与旧的 grad_fn 相关联,例如,当 autograd 引擎在图中到达该张量的旧 grad_fn 时,它们将被触发。