Autograd 机制#
创建于: 2017年1月16日 | 最后更新于: 2025年6月16日
本文档将概述 autograd 的工作原理以及如何记录操作。深入理解所有这些细节并非严格必需,但我们推荐您熟悉它们,因为这将有助于您编写更高效、更简洁的程序,并能帮助您进行调试。
Autograd 如何编码历史#
Autograd 是一个反向自动微分系统。从概念上讲,autograd 在您执行操作时会记录一个图,该图记录了创建数据的所有操作,从而为您提供一个有向无环图,其叶子是输入张量,根是输出张量。通过从根到叶跟踪此图,您可以使用链式法则自动计算梯度。
在内部,autograd 将此图表示为一系列 Function
对象(实际上是表达式),这些对象可以通过 apply()
来计算图的求值结果。在计算前向传播时,autograd 会同时执行请求的计算,并构建一个表示计算梯度的函数图(每个 torch.Tensor
的 .grad_fn
属性是进入此图的入口点)。前向传播完成后,我们在反向传播中求值此图以计算梯度。
需要注意的重要一点是,图在每次迭代时都会从头开始重新创建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时更改图的整体形状和大小。您不必在启动训练之前编码所有可能的路径——您运行的就是您需要微分的。
保存的张量#
某些操作在反向传播执行时需要保存中间结果。例如,函数 保存输入 以便计算梯度。
在定义自定义 Python Function
时,您可以使用 save_for_backward()
在前向传播过程中保存张量,并使用 saved_tensors
在反向传播过程中检索它们。有关更多信息,请参阅 扩展 PyTorch。
对于 PyTorch 已定义的(例如 torch.pow()
)操作,张量会在需要时自动保存。您可以(出于教育或调试目的)通过查找以 _saved
为前缀的属性来探索某个 grad_fn
保存了哪些张量。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self
指向的 Tensor 对象与 x 是同一个。但情况不一定总是如此。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底层,为了防止引用循环,PyTorch 在保存张量时将其 *打包*,并在读取时将其 *解包* 到另一个不同的张量中。在此,您通过访问 y.grad_fn._saved_result
获得的张量是不同于 y
的张量对象(但它们仍然共享相同的存储)。
张量是否会被打包成不同的张量对象,取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生变化,用户不应依赖它。
您可以通过 已保存张量的钩子 来控制 PyTorch 的打包/解包方式。
不可微函数的梯度#
使用自动微分计算梯度仅在所使用的每个基本函数可微时才有效。不幸的是,我们在实践中使用的许多函数不具备此属性(例如 relu
或 sqrt
在 0
处)。为了尽量减少不可微函数的影响,我们通过按以下顺序应用规则来定义基本操作的梯度:
如果函数可微且在当前点存在梯度,则使用它。
如果函数是凸函数(至少在局部),则使用最小范数的次梯度。
如果函数是凹函数(至少在局部),则使用最小范数的超梯度(考虑 -f(x) 并应用前一点)。
如果函数已定义,则通过连续性定义当前点的梯度(请注意,这里可能出现
inf
,例如对于sqrt(0)
)。如果可能存在多个值,则任意选择一个。如果函数未定义(例如
sqrt(-1)
、log(-1)
或大多数函数在输入为NaN
时),则使用的梯度值为任意值(我们也可能引发错误,但这并非必然)。大多数函数将使用NaN
作为梯度,但出于性能原因,某些函数将使用其他值(例如log(-1)
)。如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微。这将在反向传播时导致错误,除非在
no_grad
环境之外使用需要 grad 的张量。
Autograd 中的除零错误#
在 PyTorch 中执行除零(例如 x / 0
)时,前向传播将根据 IEEE-754 浮点算术产生 inf
值。虽然这些 inf
值可以在计算最终损失之前被屏蔽掉(例如通过索引或掩码),但 autograd 系统仍然会跟踪并通过完整的计算图进行微分,包括除零操作。
在反向传播过程中,这可能导致不稳定的梯度表达式。例如
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div # Results in [inf, 1]
mask = div != 0 # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad) # [nan, 1], not [0, 1]
在此示例中,即使我们只使用掩码输出(它排除了除零),autograd 仍会通过包括除零操作在内的完整计算图来计算梯度。这会导致掩码元素的梯度为 nan
,这可能会导致训练不稳定。
为避免此问题,有几种推荐方法:
在除法前进行掩码
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward() # Produces safe gradients [0, 1]
使用 MaskedTensor(实验性 API)
from torch.masked import as_masked_tensor
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward() # Cleanly handles "undefined" vs "zero" gradients
关键原则是防止除零操作被记录在计算图中,而不是事后掩盖其结果。这确保了 autograd 只通过有效操作计算梯度。
在处理可能产生 inf
或 nan
值的操作时,牢记此行为非常重要,因为掩盖输出并不能阻止计算有问题的梯度。
局部禁用梯度计算#
Python 提供了几种机制来局部禁用梯度计算:
要禁用整个代码块的梯度,可以使用 no-grad 模式和 inference 模式等上下文管理器。对于从梯度计算中排除子图的更细粒度控制,可以设置张量的 requires_grad
字段。
下面,除了讨论上述机制外,我们还描述了 evaluation 模式(nn.Module.eval()
),此方法不用于禁用梯度计算,但因其名称而经常与这三者混淆。
设置 requires_grad
#
requires_grad
是一个标志,默认为 false,*除非包装在* nn.Parameter
*中*,它允许从梯度计算中细粒度地排除子图。它在前向和反向传播中都生效:
在前向传播期间,只有当至少一个输入张量需要 grad 时,操作才会被记录在反向传播图中。在反向传播(.backward()
)期间,只有 requires_grad=True
的叶子张量才会将梯度累积到其 .grad
字段中。
需要注意的是,虽然每个张量都有此标志,但 *设置* 它仅对叶子张量有意义(即没有 grad_fn
的张量,例如 nn.Module
的参数)。非叶子张量(即具有 grad_fn
的张量)是与之关联了反向传播图的张量。因此,它们需要梯度作为中间结果来计算需要 grad 的叶子张量的梯度。从这个定义可以看出,所有非叶子张量都将自动具有 require_grad=True
。
设置 requires_grad
应该是您控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的部分。
要冻结模型的部分,只需将您不希望更新的参数应用 .requires_grad_(False)
。如上所述,由于使用这些参数作为输入的计算在前向传播中不会被记录,因此它们不会在反向传播中更新其 .grad
字段,因为它们根本不会成为反向传播图的一部分,正是我们想要的。
由于这是一个非常常见的模式,requires_grad
也可以通过模块级别的 nn.Module.requires_grad_()
来设置。应用于模块时,.requires_grad_()
会影响模块的所有参数(这些参数默认具有 requires_grad=True
)。
Grad 模式#
除了设置 requires_grad
之外,还可以从 Python 中选择三种 grad 模式,它们会影响 autograd 在内部如何处理 PyTorch 中的计算:默认模式(grad 模式)、no-grad 模式和 inference 模式,所有这些都可以通过上下文管理器和装饰器进行切换。
模式 |
从反向传播图中排除操作 |
跳过额外的 autograd 跟踪开销 |
在模式启用时创建的张量稍后可在 grad 模式下使用 |
示例 |
---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
no-grad |
✓ |
✓ |
优化器更新 |
|
inference |
✓ |
✓ |
数据处理、模型评估 |
默认模式(Grad 模式)#
“默认模式”是我们隐式处于的模式,当没有启用其他模式(如 no-grad 和 inference 模式)时。与“no-grad 模式”相对,默认模式有时也称为“grad 模式”。
关于默认模式,最重要的一点是,它是唯一一个 requires_grad
生效的模式。在其他两种模式下,requires_grad
始终被覆盖为 False
。
No-grad 模式#
no-grad 模式下的计算行为就像输入都不需要 grad 一样。换句话说,no-grad 模式下的计算即使有 require_grad=True
的输入,也不会被记录在反向传播图中。
当您需要执行不应被 autograd 记录的操作,但仍希望稍后在 grad 模式下使用这些计算的输出时,可以启用 no-grad 模式。此上下文管理器可以方便地为代码块或函数禁用梯度,而无需临时将张量设置为 requires_grad=False
,然后再恢复为 True
。
例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望就地更新参数,而不让 autograd 记录更新。您还打算在下一个前向传播中使用更新后的参数进行计算。
torch.nn.init
中的实现也依赖于 no-grad 模式,在初始化参数时,以避免在就地更新初始化参数时被 autograd 跟踪。
Inference 模式#
Inference 模式是 no-grad 模式的极端版本。与 no-grad 模式一样,inference 模式下的计算也不会被记录在反向传播图中,但启用 inference 模式可以进一步加快模型的速度。这种更好的运行时性能有一个缺点:在 inference 模式下创建的张量在退出 inference 模式后将无法用于 autograd 记录的计算。
当您执行不与 autograd 交互的计算,并且 *不* 打算稍后在 autograd 记录的任何计算中使用在 inference 模式下创建的张量时,请启用 inference 模式。
我们建议您在不需要 autograd 跟踪的代码部分(例如数据处理和模型评估)尝试 inference 模式。如果它能开箱即用地满足您的用例,那么它就是一个免费的性能提升。如果您在启用 inference 模式后遇到错误,请检查您是否在退出 inference 模式后,在 autograd 记录的计算中使用了在 inference 模式下创建的张量。如果您无法在您的用例中避免此类使用,您可以随时切换回 no-grad 模式。
有关 inference 模式的详细信息,请参阅 Inference Mode。
有关 inference 模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式(nn.Module.eval()
)#
评估模式不是一种局部禁用梯度计算的机制。之所以在此提及,是因为它有时会被误认为是这样一种机制。
从功能上讲,module.eval()
(或等效的 module.train(False)
)与 no-grad 模式和 inference 模式是完全正交的。model.eval()
如何影响您的模型完全取决于您模型中使用的具体模块,以及它们是否定义了任何训练模式特定的行为。
如果您使用的模型依赖于 torch.nn.Dropout
和 torch.nn.BatchNorm2d
等模块,它们在训练模式下行为可能不同,例如,为了避免在验证数据上更新 BatchNorm 的运行统计数据,您有责任调用 model.eval()
和 model.train()
。
建议您在训练时始终使用 model.train()
,在评估模型(验证/测试)时使用 model.eval()
,即使您不确定模型是否具有训练模式特定的行为,因为您正在使用的模块可能会更新为在训练和评估模式下表现不同。
Autograd 的就地(in-place)操作#
在 autograd 中支持就地操作是一个棘手的问题,在大多数情况下我们不建议使用它们。Autograd 激进的缓冲区释放和重用使其非常高效,而且很少有情况能够通过就地操作显著降低内存使用量。除非您处于严重的内存压力下,否则您可能永远不需要使用它们。
限制就地操作适用性的两个主要原因:
就地操作可能会覆盖计算梯度所需的值。
每个就地操作都需要实现来重写计算图。原地(out-of-place)版本只是分配新对象并保留对旧图的引用,而就地操作需要将所有输入的创建者更改为表示此操作的
Function
。这可能会很棘手,特别是当有许多张量引用同一存储(例如通过索引或转置创建)时,并且如果任何其他Tensor
引用了已修改输入的存储,就地函数将引发错误。
就地操作正确性检查#
每个张量都维护一个版本计数器,该计数器在每次在任何操作中将其标记为“脏”时递增。当一个 Function 为反向传播保存任何张量时,其包含张量的版本计数器也会被保存。一旦您访问 self.saved_tensors
,它就会被检查,如果它大于保存的值,则会引发错误。这确保了如果您使用就地函数而没有看到任何错误,您可以确信计算出的梯度是正确的。
多线程 Autograd#
autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将详细介绍有助于您在多线程环境中充分利用它的所有细节。(这仅与 PyTorch 1.6+ 相关,因为早期版本的行为不同。)
用户可以使用多线程代码(例如 Hogwild 训练)来训练模型,并且不阻塞并发反向计算,示例如代码:
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意,用户应注意的一些行为:
CPU 上的并发#
当您在 CPU 上的多个线程中通过 Python 或 C++ API 运行 backward()
或 grad()
时,您期望看到额外的并发,而不是在执行期间将所有反向调用串行化(PyTorch 1.6 之前的行为)。
非确定性#
如果您从多个线程并发调用 backward()
,并且有共享输入(即 Hogwild CPU 训练),那么应该会遇到非确定性。这可能发生,因为参数会自动在线程之间共享,因此,多个线程可能在梯度累积期间访问和尝试累加相同的 .grad
属性。这在技术上不安全,并且可能导致竞态条件,导致结果无效。
具有共享参数的多线程模型的用户应牢记线程模型,并理解上述问题。
可以使用函数 API torch.autograd.grad()
来计算梯度,而不是 backward()
,以避免非确定性。
图保留#
如果 autograd 图的一部分在线程之间共享,即在单线程中运行前向传播的第一部分,然后在多线程中运行第二部分,那么图的第一部分就是共享的。在这种情况下,不同线程在同一图上执行 grad()
或 backward()
可能会遇到一个线程在运行时破坏图的问题,而另一个线程在这种情况下会崩溃。Autograd 将像调用 backward()
两次而没有 retain_graph=True
一样向用户报错,并告知用户他们应该使用 retain_graph=True
。
Autograd 节点上的线程安全#
由于 Autograd 允许调用线程驱动其反向执行以实现潜在的并行化,因此确保 CPU 上的线程安全至关重要,即使存在共享图任务的并行 backward()
调用。
自定义 Python autograd.Function
由于 GIL 的存在,自动是线程安全的。对于内置 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function
,Autograd 引擎使用线程互斥锁来确保对可能具有状态读写操作的 Autograd 节点的线程安全。
C++ 钩子上的线程不安全#
Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中正确应用,您需要编写适当的线程锁定代码来确保钩子是线程安全的。
复数 Autograd#
简而言之:
当您使用 PyTorch 对任何具有复数域/值域的函数 进行微分时,梯度是在假定该函数是更大的实值损失函数 的一部分的情况下计算的。计算出的梯度是 (请注意 z 的共轭),其负值正是梯度下降算法使用的最陡下降方向。因此,有一个可行的方法可以使现有的优化器直接与复数参数一起工作。
此约定与 TensorFlow 的复数微分约定匹配,但与 JAX 不同(JAX 计算 )。
如果您有一个内部使用复数运算的实值到实值的函数,那么这里的约定并不重要:您将始终获得与使用仅实数运算实现它相同的结果。
如果您对数学细节感到好奇,或者想了解如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?#
复数可微性的数学定义采用导数的极限定义并将其推广到对复数进行运算。考虑一个函数 ,
其中 和 是两个变量实值函数, 是虚数单位。
使用导数定义,我们可以写出:
为了使此极限存在,不仅 和 必须是实可微的,而且 还必须满足柯西-黎曼 方程。换句话说:对于实部和虚部步骤()计算的极限必须相等。这是一个更严格的条件。
复数可微函数通常被称为全纯函数。它们行为良好,具有您从实数可微函数中看到的所有良好特性,但在优化世界中几乎没有用处。对于优化问题,研究界只使用实值目标函数,因为复数不属于任何有序域,所以具有复数值损失意义不大。
事实证明,没有有趣的实值目标会满足柯西-黎曼方程。因此,全纯函数的理论不能用于优化,因此大多数人使用维尔丁格微积分。
维尔丁格微积分登场……#
因此,我们拥有了关于复数可微性和全纯函数的伟大理论,但由于许多常用函数不是全纯的,我们几乎无法使用它。可怜的数学家该怎么办?维尔丁格观察到,即使 不是全纯的,也可以将其重写为二元函数 ,该函数总是全纯的。这是因为 的实部和虚部可以表示为 和 的形式:
该函数具有偏导数 和 。我们可以使用链式法则在这些偏导数与 的实部和虚部偏导数之间建立关系。
从以上方程,我们得到
其中“j”是虚数单位。这是你在维基百科上找到的经典的Wirtinger微积分定义。
这个变化带来许多优美的结果。
例如,柯西-黎曼方程可以简化为 (也就是说,函数 可以完全用 来表示,而不需要引用 。
另一个重要(且有些违反直觉)的结果是,正如我们稍后将看到的,当我们对实值损失进行优化时,变量更新时应采取的步骤由 (而不是 )。
更多阅读,请参阅:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger微积分在优化中有什么用?#
音频和其他领域的 शोध者,更常见的是使用梯度下降来优化具有复数的实值损失函数。通常,这些人将实部和虚部分别视为可以更新的通道。对于步长 和损失 ,我们可以写出 中的以下方程:
这些方程如何转换为复数空间 呢?
发生了非常有趣的事情:Wirtinger微积分告诉我们,我们可以将上面复数变量更新公式简化为仅引用共轭Wirtinger导数 ,为我们提供了优化中所需的恰当步长。
由于共轭Wirtinger导数提供了实值损失函数的正确步长,PyTorch 在对具有实值损失的函数进行微分时,会返回此导数。
PyTorch 如何计算共轭Wirtinger 导数?#
通常,我们的导数公式将 grad_output 作为输入,代表已计算的传入向量-雅可比行列式积,也称为 ,其中 是整个计算的损失(产生实值损失), 是我们函数的输出。这里的目标是计算 ,其中 是函数的输入。结果表明,在实值损失的情况下,我们可以仅计算 ,尽管链式法则暗示我们也需要 的访问权限。如果想跳过此推导,请查看本节的最后一个方程,然后跳到下一节。
我们继续处理 ,其定义为 。如上所述,autograd 的梯度约定侧重于实值损失函数的优化,因此我们假设 是更大的实值损失函数 的一部分。使用链式法则,我们可以写出:
(1)#
现在根据Wirtinger导数的定义,我们可以写出:
需要注意的是,由于 和 是实函数,并且根据我们对 是实值函数一部分的假设, 也是实数,我们有
(2)#
即, 等于 。
求解上述关于 和 的方程,我们得到
(3)#
使用 (2),我们得到
(4)#
最后一个方程对于编写您自己的梯度很重要,因为它将我们的导数公式分解为更简单、易于手工计算的公式。
我如何编写自己的复杂函数导数公式?#
上述带框的方程为所有复杂函数导数提供了通用公式。然而,我们仍然需要计算 和 。
第一种方法是直接使用 Wirtinger 导数的定义,计算 和 使用 和 (这可以通过常规方式计算)。
第二种方法是使用变量替换技巧,将 重写为双变量函数 ,并通过将 和 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯的,则只会使用 (并且 将为零)。
让我们以 作为示例,其中 。
使用第一种方法计算 Wirtinger 导数,我们得到。
根据 (4) 和 grad_output = 1.0(在 PyTorch 中调用 backward()
时标量输出使用的默认 grad 输出值)计算,我们得到
使用第二种方法计算 Wirtinger 导数,我们直接得到
再根据 (4),我们得到 。正如你所见,第二种方法计算量更少,而且对于更快的计算来说更方便。
已保存张量的钩子#
您可以通过定义一对 pack_hook
/ unpack_hook
钩子来控制 已保存张量的打包/解包方式。 pack_hook
函数应接受一个张量作为其唯一参数,但可以返回任何 Python 对象(例如,另一个张量、一个元组,甚至一个包含文件名的字符串)。 unpack_hook
函数将 pack_hook
的输出作为其唯一参数,并应返回一个用于反向传播的张量。 unpack_hook
返回的张量只需要具有与传递给 pack_hook
的张量相同的内容。特别是,任何与 autograd 相关的元数据都可以被忽略,因为它们将在解包过程中被覆盖。
这样一对钩子中的一个例子是
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook
不应删除临时文件,因为它可能会被多次调用:只要返回的 SelfDeletingTempFile 对象存在,临时文件就应该保持有效。在上面的例子中,我们通过在 SelfDeletingTempFile 对象被删除时关闭它来防止临时文件泄露(当 SelfDeletingTempFile 对象被删除时)。
注意
我们保证 pack_hook
只会被调用一次,而 unpack_hook
可以在反向传播所需的任何时候被调用,并且我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行原地操作,因为这可能导致意外的副作用。如果输入到 pack hook 的张量被原地修改,PyTorch 将抛出错误,但它不会捕获 unpack hook 的输入被原地修改的情况。
为已保存张量注册钩子#
您可以通过在 SavedTensor
对象上调用 register_hooks()
方法来为已保存张量注册一对钩子。这些对象作为 grad_fn
的属性公开,并以 _raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
一旦注册了这对钩子,pack_hook
方法就会被调用。每次需要访问已保存张量时(无论是通过 y.grad_fn._saved_self
还是在反向传播过程中),unpack_hook
方法都会被调用。
警告
如果在已保存张量被释放后(即反向传播调用后)仍然保留对 SavedTensor
的引用,则禁止调用其 register_hooks()
。PyTorch 大多数情况下会抛出错误,但在某些情况下可能无法做到,并可能导致未定义行为。
为已保存张量注册默认钩子#
或者,您可以使用上下文管理器 saved_tensors_hooks
来注册一对钩子,这些钩子将应用于该上下文中创建的 *所有* 已保存张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x.detach()
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程局部的。因此,以下代码将不会产生预期的效果,因为钩子不会通过 DataParallel。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有为了减少 Tensor 对象创建而进行的原地优化。例如
with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
在没有钩子的情况下,x
、y.grad_fn._saved_self
和 y.grad_fn._saved_other
都引用同一个张量对象。使用钩子后,PyTorch 会将 x 打包和解包到两个新的张量对象中,这些对象与原始 x 共享相同的存储(不执行复制)。
反向钩子的执行#
本节将讨论不同钩子何时触发或不触发。然后将讨论它们触发的顺序。将涵盖的钩子有:通过 torch.Tensor.register_hook()
注册到 Tensor 的反向钩子,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的累积梯度后钩子,通过 torch.autograd.graph.Node.register_hook()
注册到 Node 的后钩子,以及通过 torch.autograd.graph.Node.register_prehook()
注册到 Node 的前钩子。
某个钩子是否会被触发#
通过 torch.Tensor.register_hook()
注册到 Tensor 的钩子在计算该 Tensor 的梯度时执行。(请注意,这不需要执行 Tensor 的 grad_fn。例如,如果 Tensor 作为 torch.autograd.grad()
的 inputs
参数的一部分传递,则 Tensor 的 grad_fn 可能不会执行,但注册到该 Tensor 的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的钩子在计算该 Tensor 的梯度并设置了 Tensor 的 grad 字段后执行。而通过 torch.Tensor.register_hook()
注册的钩子在计算梯度时运行,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册的钩子仅在 Tensor 的 grad 字段在反向传播结束时被 autograd 更新后触发。因此,累积梯度后钩子只能为叶子 Tensor 注册。在非叶子 Tensor 上通过 torch.Tensor.register_post_accumulate_grad_hook()
注册钩子将会报错,即使您调用了 backward(retain_graph=True)。
使用 torch.autograd.graph.Node.register_hook()
或 torch.autograd.graph.Node.register_prehook()
注册到 torch.autograd.graph.Node
的钩子仅在注册了它的 Node 被执行时才会触发。
某个 Node 是否被执行可能取决于反向传播是否使用了 torch.autograd.grad()
或 torch.autograd.backward()
调用。具体来说,当您在传递给 torch.autograd.grad()
或 torch.autograd.backward()
的 Tensor 的 inputs
参数的一部分的 Node 上注册钩子时,您应该注意这些区别。
如果您使用的是 torch.autograd.backward()
,则所有上述钩子都将被执行,无论您是否指定了 inputs
参数。这是因为 .backward() 会执行所有 Node,即使它们对应于作为输入指定的 Tensor。(请注意,此额外 Node 的执行对应于传递给 inputs
的 Tensor 通常是不必要的,但仍然执行。此行为可能会发生更改;您不应依赖它。)
另一方面,如果您使用的是 torch.autograd.grad()
,则传递给 input
的 Tensor 对应的 Node 上的反向钩子可能不会被执行,因为除非有另一个输入依赖于此 Node 的梯度结果,否则这些 Node 不会被执行。
不同钩子触发的顺序#
事件发生的顺序是
注册到 Tensor 的钩子被执行
注册到 Node 的前钩子被执行(如果 Node 被执行)。
具有
retain_grad
的 Tensor 的.grad
字段被更新Node 被执行(受上述规则约束)
对于累积了
.grad
的叶子 Tensor,将执行累积梯度后钩子注册到 Node 的后钩子被执行(如果 Node 被执行)
如果多个相同类型的钩子被注册到同一个 Tensor 或 Node 上,它们将按照注册的顺序执行。稍后执行的钩子可以观察到早期钩子对梯度的修改。
特殊钩子#
torch.autograd.graph.register_multi_grad_hook()
是通过注册到 Tensor 的钩子实现的。每个 Tensor 钩子都遵循上述 Tensor 钩子的顺序触发,当最后一个 Tensor 梯度计算完成后,已注册的多梯度钩子将被调用。
torch.nn.modules.module.register_module_full_backward_hook()
是通过注册到 Node 的钩子实现的。在计算前向传播时,钩子会被注册到与模块输入和输出对应的 grad_fn 上。由于一个模块可能接受多个输入并返回多个输出,因此在模块前向传播之前,会先在模块的输入上应用一个虚拟的自定义 autograd Function,在模块输出之后,在返回模块输出之前,以确保这些 Tensor 共享一个 grad_fn,然后我们就可以将钩子附加到它上面了。
当 Tensor 被原地修改时 Tensor 钩子的行为#
通常,注册到 Tensor 的钩子接收到输出相对于该 Tensor 的梯度,其中 Tensor 的值被认为是其在计算反向传播时该 Tensor 的值。
然而,如果您向 Tensor 注册了钩子,然后原地修改了该 Tensor,那么在原地修改之前注册的钩子将类似地接收输出相对于该 Tensor 的梯度,但 Tensor 的值将被视为其在原地修改之前的值。
如果您更喜欢前一种情况下的行为,您应该在对 Tensor 进行所有原地修改之后再注册钩子。例如
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,值得注意的是,在底层,当钩子注册到 Tensor 时,它们实际上会永久绑定到该 Tensor 的 grad_fn,因此如果该 Tensor 随后被原地修改,即使 Tensor 现在有了新的 grad_fn,在原地修改之前注册的钩子仍将与旧的 grad_fn 相关联,例如,当 autograd 引擎在图中到达该 Tensor 的旧 grad_fn 时,它们会触发。