自动求导机制#
创建于:2017年1月16日 | 最后更新于:2026年1月6日
本说明将概述自动求导(autograd)的工作原理及其如何记录操作。虽然不必完全理解所有细节,但建议您熟悉这些内容,因为这有助于编写更高效、更整洁的程序,并能辅助调试。
自动求导如何编码历史记录#
自动求导是一个反向自动微分系统。从概念上讲,自动求导会在您执行操作时记录一个包含所有创建数据操作的图,从而为您提供一个有向无环图(DAG),其叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点跟踪此图,您可以使用链式法则自动计算梯度。
在内部,自动求导将此图表示为 Function 对象(实际上是表达式)的图,这些对象可以通过 apply() 来计算图评估的结果。在计算前向传播时,自动求导会同时执行所请求的计算并构建一个表示计算梯度函数的图(每个 torch.Tensor 的 .grad_fn 属性都是此图的入口点)。当前向传播完成后,我们会在反向传播中评估此图以计算梯度。
需要注意的一点是,该图在每次迭代时都会从头开始重建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时改变图的整体形状和大小。您不必在启动训练之前对所有可能的路径进行编码——您运行什么,就微分什么。
保存的张量#
某些操作需要在前向传播期间保存中间结果,以便执行反向传播。例如,函数 会保存输入 以便计算梯度。
当定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传播期间保存张量,并在反向传播期间使用 saved_tensors 检索它们。有关详细信息,请参阅 扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()),张量会根据需要自动保存。您可以通过查找以 _saved 为前缀的属性,来探索(出于教育或调试目的)某个 grad_fn 保存了哪些张量。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self 指向与 x 相同的张量对象。但情况并非总是如此。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底层,为了防止引用循环,PyTorch 在保存时已将张量打包(packed),并在读取时将其解包(unpacked)为不同的张量。此处,从访问 y.grad_fn._saved_result 获得的张量是一个与 y 不同的张量对象(但它们仍共享相同的存储空间)。
一个张量是否会被打包成不同的张量对象,取决于它是否为其自身 grad_fn 的输出,这是一个可能发生变化的实现细节,用户不应依赖于此。
您可以使用 保存张量的钩子 (Hooks for saved tensors) 来控制 PyTorch 如何进行打包/解包。
不可微函数的梯度#
使用自动微分进行的梯度计算仅在所使用的每个基本函数均可微分时才有效。遗憾的是,我们在实践中使用的许多函数并不具备此属性(例如 relu 或 sqrt 在 0 处)。为了尽量减少不可微函数的影响,我们通过按顺序应用以下规则来定义基本操作的梯度:
如果函数是可微的,且当前点存在梯度,则使用该梯度。
如果函数是凸函数(至少在局部),则使用最小范数的次梯度(sub-gradient)。
如果函数是凹函数(至少在局部),则使用最小范数的超梯度(super-gradient)(考虑 -f(x) 并应用前一点)。
如果函数已定义,则通过连续性定义当前点的梯度(注意此处可能会出现
inf,例如对于sqrt(0))。如果可能存在多个值,则随意选取一个。如果函数未定义(例如
sqrt(-1)、log(-1)或当输入为NaN时的大多数函数),则用作梯度的值是任意的(我们也可能会抛出错误,但这不能保证)。大多数函数将使用NaN作为梯度,但出于性能原因,某些函数会使用其他值(例如log(-1))。如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微。如果将其用于
no_grad环境之外且需要梯度的张量上,则会在反向传播中报错。
自动求导中的除以零#
在 PyTorch 中执行除以零(例如 x / 0)时,前向传播将产生遵循 IEEE-754 浮点运算的 inf 值。虽然可以在计算最终损失之前(例如通过索引或掩码)屏蔽这些 inf 值,但自动求导系统仍会跟踪并对完整的计算图进行微分,包括除以零的操作。
在反向传播期间,这可能会导致出现有问题的梯度表达式。例如
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div # Results in [inf, 1]
mask = div != 0 # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad) # [nan, 1], not [0, 1]
在此示例中,即使我们只使用掩码后的输出(不包括除以零的部分),自动求导仍会通过完整的计算图计算梯度,包括除以零的操作。这会导致掩码元素出现 nan 梯度,从而可能导致训练不稳定。
为了避免此问题,有几种推荐的方法:
除法前进行掩码处理
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward() # Produces safe gradients [0, 1]
使用 MaskedTensor(实验性 API)
from torch.masked import as_masked_tensor
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward() # Cleanly handles "undefined" vs "zero" gradients
关键原则是防止除以零操作被记录在计算图中,而不是事后掩码其结果。这确保了自动求导仅通过有效操作计算梯度。
在使用可能产生 inf 或 nan 值的操作时,请务必牢记此行为,因为对输出进行掩码并不能防止计算出有问题的梯度。
局部禁用梯度计算#
Python 提供了几种在本地禁用梯度计算的机制:
要跨整个代码块禁用梯度,可以使用如 no-grad 模式和推理模式(inference mode)的上下文管理器。对于更细粒度地将子图从梯度计算中排除,可以设置张量的 requires_grad 字段。
下面,除了讨论上述机制外,我们还将描述评估模式(nn.Module.eval()),这是一种不用于禁用梯度计算的方法,但由于其名称,经常与前三者混淆。
设置 requires_grad#
requires_grad 是一个标志,默认为 false(*除非包装在* nn.Parameter 中),它允许从梯度计算中细粒度地排除子图。它在前向和反向传播中均生效。
在前向传播期间,只有当至少一个输入张量需要梯度时,操作才会被记录在反向图中。在反向传播(.backward())期间,只有 requires_grad=True 的叶张量才会将梯度累积到其 .grad 字段中。
需要注意的是,即使每个张量都有此标志,设置它也仅对叶张量(没有 grad_fn 的张量,例如 nn.Module 的参数)有意义。非叶张量(确实有 grad_fn 的张量)是与反向图关联的张量。因此,计算需要梯度的叶张量的梯度时,需要它们的梯度作为中间结果。根据此定义,很明显所有非叶张量都将自动具有 require_grad=True。
设置 requires_grad 应该是您控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的部分内容。
要冻结模型的部分内容,只需对不想更新的参数应用 .requires_grad_(False)。如上所述,由于使用这些参数作为输入的计算不会在前向传播中被记录,因此它们的 .grad 字段在反向传播中不会被更新,因为它们根本不会成为反向图的一部分(这正是我们想要的)。
由于这是一种非常常见的模式,requires_grad 也可以在模块级别通过 nn.Module.requires_grad_() 设置。当应用于模块时,.requires_grad_() 对模块的所有参数(默认情况下具有 requires_grad=True)生效。
梯度模式#
除了设置 requires_grad 外,还可以从 Python 中选择三种梯度模式,这些模式可能会影响 PyTorch 中自动求导在内部处理计算的方式:默认模式(梯度模式)、no-grad 模式和推理模式(inference mode),所有这些都可以通过上下文管理器和装饰器进行切换。
模式 |
排除操作使其不被记录在反向图中 |
跳过额外的自动求导跟踪开销 |
在该模式启用时创建的张量可以在稍后的梯度模式下使用 |
示例 |
|---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
no-grad |
✓ |
✓ |
优化器更新 |
|
推理 (inference) |
✓ |
✓ |
数据处理、模型评估 |
默认模式(梯度模式)#
“默认模式”是我们未启用如 no-grad 和推理模式等其他模式时隐式处于的模式。为了与“no-grad 模式”形成对比,默认模式有时也称为“梯度模式”。
关于默认模式最重要的一点是,它是唯一使 requires_grad 生效的模式。requires_grad 在其他两种模式中始终被覆盖为 False。
No-grad 模式#
No-grad 模式下的计算表现得就像没有任何输入需要梯度一样。换句话说,即使存在 require_grad=True 的输入,no-grad 模式下的计算也永远不会被记录在反向图中。
当您需要执行不应被自动求导记录的操作,但稍后又希望在梯度模式下使用这些计算的输出时,请启用 no-grad 模式。此上下文管理器可以方便地为代码块或函数禁用梯度,而不必临时将张量设置为 requires_grad=False,然后再改回 True。
例如,编写优化器时 no-grad 模式可能很有用:执行训练更新时,您希望原地(in-place)更新参数,而不希望该更新被自动求导记录。您还打算在下一次前向传播的梯度模式中将更新后的参数用于计算。
在初始化参数时,torch.nn.init 中的实现也依赖于 no-grad 模式,以避免在原地更新初始化参数时进行自动求导跟踪。
推理模式(Inference Mode)#
推理模式是 no-grad 模式的极致版本。就像在 no-grad 模式中一样,推理模式下的计算不会被记录在反向图中,但启用推理模式将允许 PyTorch 进一步加速您的模型。这种更好的运行时性能有一个代价:在推理模式下创建的张量将无法用于在退出推理模式后被自动求导记录的计算。
当您执行不与自动求导交互的计算,且不打算在稍后任何会被自动求导记录的计算中使用在推理模式下创建的张量时,请启用推理模式。
建议您在不需要自动求导跟踪的代码部分(例如数据处理和模型评估)尝试推理模式。如果它能直接适用于您的用例,那么这就是免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否在退出推理模式后,将推理模式下创建的张量用于了会被自动求导记录的计算中。如果您的情况无法避免此类使用,您可以随时切换回 no-grad 模式。
有关推理模式的详细信息,请参阅 推理模式 (Inference Mode)。
有关推理模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式(nn.Module.eval())#
评估模式不是本地禁用梯度计算的机制。此处将其包括进来是因为它有时被误认为是此类机制。
在功能上,module.eval()(或等效的 module.train(False))与 no-grad 模式和推理模式完全正交。model.eval() 如何影响您的模型,完全取决于模型中使用的特定模块以及它们是否定义了任何特定于训练模式的行为。
如果您的模型依赖于可能根据训练模式表现不同的模块(例如 torch.nn.Dropout 和 torch.nn.BatchNorm2d),您有责任调用 model.eval() 和 model.train(),例如,以避免在验证数据上更新 BatchNorm 的运行统计信息。
建议您在训练时始终使用 model.train(),在评估模型(验证/测试)时始终使用 model.eval(),即使您不确定模型是否具有特定于训练模式的行为,因为您使用的模块可能会被更新以在训练和评估模式下表现不同。
自动求导中的原地操作#
在自动求导中支持原地操作是一个难题,我们在大多数情况下不鼓励使用它们。自动求导激进的缓冲区释放和重用机制使其非常高效,并且原地操作大幅降低内存使用量的情况非常少见。除非您在巨大的内存压力下运行,否则您可能永远不需要使用它们。
限制原地操作适用性的主要有两个原因:
原地操作可能会覆盖计算梯度所需的值。
每个原地操作都需要实现重写计算图。非原地版本只需分配新对象并保留对旧图的引用,而原地操作则需要将所有输入修改为表示此操作的
Function的创建者。这可能很棘手,尤其是当有许多张量引用相同的存储空间时(例如通过索引或转置创建),如果修改输入的存储空间被任何其他Tensor引用,原地函数将报错。
原地正确性检查#
每个张量都保存一个版本计数器,该计数器在每次在任何操作中被标记为脏(dirty)时递增。当一个 Function 保存任何张量用于反向传播时,它们所包含的张量的版本计数器也会被保存。一旦您访问 self.saved_tensors,它就会被检查,如果其值大于保存的值,就会报错。这确保了如果您正在使用原地函数且没有看到任何错误,那么您可以确信计算出的梯度是正确的。
多线程自动求导#
自动求导引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有有助于您在多线程环境中最佳利用它的细节。(这仅适用于 PyTorch 1.6+,因为之前版本的行为不同。)
用户可以使用多线程代码训练他们的模型(例如 Hogwild 训练),并且不会阻塞并发的反向计算,示例代码可能是
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
注意用户应该注意的一些行为
CPU 上的并发性#
当您在 CPU 上通过 Python 或 C++ API 在多个线程中运行 backward() 或 grad() 时,您期望看到额外的并发性,而不是像执行期间那样按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。
非确定性#
如果您并发地从多个线程调用 backward() 并拥有共享输入(例如 Hogwild CPU 训练),则应预料到非确定性。这可能发生是因为参数会在线程之间自动共享,因此,多个线程可能会在梯度累积期间访问并尝试累积相同的 .grad 属性。这在技术上是不安全的,可能会导致竞争条件,并且结果可能无法使用。
开发具有共享参数的多线程模型的用户应牢记线程模型,并应了解上述问题。
可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是 backward(),以避免非确定性。
图的保留#
如果自动求导图的一部分在线程之间共享,例如在单线程中运行前向的第一部分,然后在多线程中运行第二部分,那么图的第一部分是共享的。在这种情况下,不同的线程对同一个图执行 grad() 或 backward() 可能会遇到一个线程在动态销毁图的问题,而另一个线程在这种情况下会崩溃。自动求导会向用户报错,类似于不带 retain_graph=True 调用两次 backward() 的情况,并告知用户应该使用 retain_graph=True。
自动求导节点的线程安全#
由于自动求导允许调用线程驱动其反向执行以实现潜在的并行性,因此确保 CPU 上与共享部分/全部 GraphTask 的并行 backward() 调用保持线程安全非常重要。
自定义 Python autograd.Function 由于 GIL 而自动保持线程安全。对于内置 C++ 自动求导节点(例如 AccumulateGrad, CopySlices)和自定义 autograd::Function,自动求导引擎使用线程互斥锁来确保对可能具有状态读/写的自动求导节点进行线程安全。
C++ 钩子无线程安全#
自动求导依赖于用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中被正确应用,您将需要编写适当的线程锁定代码,以确保钩子是线程安全的。
复数自动求导#
简而言之
当您使用 PyTorch 对任何具有复数域和/或陪域的函数 进行微分时,梯度的计算是基于该函数是更大的实值损失函数 的一部分这一假设。计算出的梯度是 (注意 z 的共轭),其负值正是梯度下降算法中所使用的最速下降方向。因此,存在一条使现有的优化器能够直接用于复数参数的可行路径。
此约定符合 TensorFlow 对复数微分的约定,但与 JAX 不同(JAX 计算的是 )。
如果您有一个内部使用复数运算的实数到实数函数,则此处的约定无关紧要:您总是会得到与完全使用实数运算实现时相同的结果。
如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?#
复数可微性的数学定义采用了导数的极限定义,并将其推广到复数运算。考虑一个函数 ,
其中 和 是两个变量的实值函数, 是虚数单位。
利用导数的定义,我们可以写出
为了使该极限存在,不仅 和 必须是实可微的,而且 还必须满足柯西-黎曼 方程。换句话说:对于实部和虚部步长()计算出的极限必须相等。这是一个更严格的条件。
复可微函数通常被称为全纯函数。它们表现良好,具有你在实可微函数中所见过的所有优良性质,但在优化领域几乎毫无用处。对于优化问题,研究界只使用实值目标函数,因为复数不属于任何有序域,因此拥有复数值的损失函数并没有多大意义。
此外,事实证明,没有任何有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数理论不能用于优化,大多数人因此使用 Wirtinger 微积分。
Wirtinger 微积分应运而生……#
所以,我们有了一套关于复可微性和全纯函数的伟大的理论,但我们却完全无法使用它,因为许多常用的函数都不是全纯的。一个可怜的数学家该怎么办?好吧,Wirtinger 观察到,即使 不是全纯的,也可以将其重写为一个双变量函数 ,该函数始终是全纯的。这是因为 的实部和虚部组件可以用 和 来表示:
Wirtinger 微积分建议研究 ,如果 原本是实可微的,这保证了它是全纯的(另一种思考方式是将其视为坐标系变换,从 变为 )。该函数具有偏导数 和 。我们可以利用链式法则建立这些偏导数与关于 的实部和虚部组件的偏导数之间的关系。
根据上述方程,我们得到
这正是你在 维基百科 上可以找到的 Wirtinger 微积分的经典定义。
这一变换有着许多优美的推论。
首先,柯西-黎曼方程可以简单地转化为 (也就是说,函数 可以完全用 表示,而无需涉及 )。
另一个重要(且在某种程度上违反直觉)的结果是,正如我们稍后将看到的,当我们对实值损失函数进行优化时,在更新变量时应采取的步长由 给出(而不是 )。
欲了解更多信息,请查看:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger 微积分在优化中有什么用?#
音频和其他领域的研究人员通常会使用梯度下降法来优化包含复数变量的实值损失函数。通常,这些人将实部和虚部视为可以分别更新的独立通道。对于步长 和损失函数 ,我们可以在 中写出以下方程:
这些方程如何在复数空间 中转换?
发生了一件非常有趣的事情:Wirtinger 微积分告诉我们,我们可以简化上面的复变量更新公式,使其仅涉及共轭 Wirtinger 导数 ,这正是我们在优化中所采取的步骤。
由于共轭 Wirtinger 导数能为实值损失函数提供完全正确的步骤,因此当您对具有实值损失的函数求导时,PyTorch 会为您提供此导数。
PyTorch 是如何计算共轭 Wirtinger 导数的?#
通常,我们的导数公式会将 grad_output 作为输入,它表示我们已经计算出的传入向量-雅可比乘积 (Vector-Jacobian product),即 ,其中 是整个计算过程的损失(产生实值损失),而 是我们函数的输出。这里的目标是计算 ,其中 是函数的输入。事实证明,在实值损失的情况下,我们只需要计算 ,尽管链式法则意味着我们需要访问 。如果您想跳过此推导,请查看本节的最后一个方程,然后跳转到下一节。
让我们继续使用 ,定义为 。如上所述,autograd 的梯度约定以实值损失函数的优化为核心,因此我们假设 是更大的实值损失函数 的一部分。利用链式法则,我们可以写成
(1)#
现在使用 Wirtinger 导数的定义,我们可以写出:
此处需要注意,由于 和 是实函数,并且根据我们假设 是实值函数的一部分从而使得 为实数,我们有:
(2)#
即 等于 。
求解上述方程组中的 和 ,我们得到:
(3)#
利用 (2),我们得到
(4)#
最后一个方程对于编写自己的梯度至关重要,因为它将我们的导数公式分解为一个更简单的公式,易于手动计算。
如何为复杂函数编写自己的导数公式?#
上面的方框方程为复杂函数上的所有导数提供了通用公式。然而,我们仍然需要计算 和 。有两种方法可以实现这一点。
第一种方法是直接利用 Wirtinger 导数的定义,并利用 和 (您可以按常规方式计算)来计算 和 。
第二种方法是利用变量替换技巧,将 重写为二元函数 ,并将 和 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯的,则只会用到 (且 将为零)。
让我们以函数 为例,其中 。
使用第一种方法计算 Wirtinger 导数,我们得到:
使用 (4),并令 grad_output = 1.0(这是在 PyTorch 中对标量输出调用 backward() 时使用的默认梯度输出值),我们得到
使用第二种计算 Wirtinger 导数的方法,我们可以直接得到
再次使用 (4),我们得到 。如您所见,第二种方法涉及的计算量较少,在需要快速计算时更为方便。
跨域函数呢?#
有些函数是从复数输入映射到实数输出,反之亦然。这些函数构成了 (4) 的一个特例,我们可以利用链式法则导出它们
对于 ,我们得到
对于 ,我们得到
保存张量的钩子 (Hooks for saved tensors)#
您可以通过定义一对 pack_hook / unpack_hook 钩子来控制 如何打包/解包保存的张量。pack_hook 函数应接收一个张量作为其唯一参数,但可以返回任何 Python 对象(例如另一个张量、元组,甚至是包含文件名的字符串)。unpack_hook 函数接收 pack_hook 的输出作为其唯一参数,并应返回一个用于反向传播的张量。unpack_hook 返回的张量只需与作为输入传递给 pack_hook 的张量内容一致即可。特别是,任何与自动求导(autograd)相关的元数据都可以被忽略,因为它们会在解包过程中被覆盖。
这类钩子对的一个示例如下
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook 不应删除临时文件,因为它可能会被调用多次:临时文件应在返回的 SelfDeletingTempFile 对象存续期间保持有效。在上述示例中,我们通过在不再需要临时文件时将其关闭(即在 SelfDeletingTempFile 对象被删除时)来防止临时文件泄露。
注意
我们保证 pack_hook 只会被调用一次,但 unpack_hook 可能会根据反向传播的需要被调用多次,并且我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行原地(inplace)操作,因为这可能导致意外的副作用。如果传递给 pack hook 的输入被原地修改,PyTorch 将抛出错误,但它不会捕获传递给 unpack hook 的输入被原地修改的情况。
为保存的张量注册钩子#
你可以通过在 SavedTensor 对象上调用 register_hooks() 方法,为保存的张量注册一对钩子。这些对象作为 grad_fn 的属性公开,并以 _raw_saved_ 为前缀。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
pack_hook 方法在注册该对钩子时立即被调用。unpack_hook 方法在每次需要访问保存的张量时被调用,无论是通过 y.grad_fn._saved_self 还是在反向传播过程中。
警告
如果你在保存的张量被释放后(即调用 backward 后)仍保留对 SavedTensor 的引用,则禁止调用其 register_hooks() 方法。PyTorch 在大多数情况下会抛出错误,但在某些情况下可能会失败,从而导致未定义的行为。
为保存的张量注册默认钩子#
或者,你可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子,这些钩子将应用于在该上下文中创建的所有保存张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x.detach()
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
通过此上下文管理器定义的钩子是线程本地的。因此,以下代码不会产生预期的效果,因为这些钩子不会通过 DataParallel 传递。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有旨在减少 Tensor 对象创建的原地优化。例如:
with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
如果不使用钩子,x、y.grad_fn._saved_self 和 y.grad_fn._saved_other 都指向同一个张量对象。使用钩子后,PyTorch 会将 x 打包并解包为两个新的张量对象,它们与原始 x 共享相同的存储空间(不执行拷贝)。
反向钩子的执行#
本节将讨论不同钩子何时触发或不触发,然后讨论它们触发的顺序。将涵盖的钩子包括:通过 torch.Tensor.register_hook() 注册到张量的反向钩子;通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的梯度累加后钩子;通过 torch.autograd.graph.Node.register_hook() 注册到节点的后置钩子;以及通过 torch.autograd.graph.Node.register_prehook() 注册到节点的前置钩子。
特定钩子是否会被触发#
通过 torch.Tensor.register_hook() 注册到张量的钩子会在为该张量计算梯度时执行。(注意,这不需要执行该张量的 grad_fn。例如,如果张量作为 inputs 参数传递给 torch.autograd.grad(),则该张量的 grad_fn 可能不会被执行,但注册到该张量的钩子始终会执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的钩子会在该张量的梯度累加完成后执行,这意味着张量的 grad 字段已被设置。而通过 torch.Tensor.register_hook() 注册的钩子是在计算梯度的过程中运行的,通过 torch.Tensor.register_post_accumulate_grad_hook() 注册的钩子仅在反向传播结束时,由自动求导(autograd)更新张量的 grad 字段后触发。因此,梯度累加后钩子只能为叶张量(leaf Tensors)注册。在非叶张量上注册此类钩子将会报错,即使你调用了 backward(retain_graph=True)。
使用 torch.autograd.graph.Node.register_hook() 或 torch.autograd.graph.Node.register_prehook() 注册到 torch.autograd.graph.Node 的钩子,只有在注册的节点被执行时才会触发。
特定节点是否被执行,可能取决于反向传播调用的是 torch.autograd.grad() 还是 torch.autograd.backward()。具体来说,当你在与作为 inputs 参数传递给 torch.autograd.grad() 或 torch.autograd.backward() 的张量对应的节点上注册钩子时,你应该意识到这些差异。
如果你使用的是 torch.autograd.backward(),无论你是否指定了 inputs 参数,上述所有钩子都将被执行。这是因为 .backward() 会执行所有节点,即使它们对应于指定为输入的张量。(注意,执行与作为 inputs 传递的张量对应的额外节点通常是不必要的,但仍然会执行。此行为可能会发生变化;你不应依赖它。)
另一方面,如果你使用的是 torch.autograd.grad(),注册到对应于传递给 input 的张量的节点的反向钩子可能不会被执行,因为除非有另一个输入依赖于该节点的梯度结果,否则这些节点不会被执行。
不同钩子的触发顺序#
发生顺序如下:
执行注册到张量的钩子
执行注册到节点的前置钩子(如果节点被执行)
对于保留了梯度的张量,更新其
.grad字段执行节点(受上述规则约束)
对于已累加
.grad的叶张量,执行梯度累加后钩子执行注册到节点的后置钩子(如果节点被执行)
如果同一个张量或节点上注册了多个相同类型的钩子,它们将按照注册顺序执行。后续执行的钩子可以观察到由先前钩子对梯度所做的修改。
特殊钩子#
torch.autograd.graph.register_multi_grad_hook() 是通过注册到张量的钩子实现的。每个独立的张量钩子遵循上述定义的张量钩子顺序触发,而当最后一个张量梯度计算完成时,注册的多梯度钩子会被调用。
torch.nn.modules.module.register_module_full_backward_hook() 是通过注册到节点的钩子实现的。在前向计算时,钩子被注册到对应于模块输入和输出的 grad_fn 上。由于一个模块可能有多个输入并返回多个输出,因此在进行前向计算之前,一个自定义的虚拟 autograd 函数会被应用于模块的输入,并在前向计算的输出返回前应用于模块的输出,以确保这些张量共享同一个 grad_fn,这样我们就可以将钩子挂载到上面。
当张量被原地修改时张量钩子的行为#
通常,注册到张量的钩子接收输出相对于该张量的梯度,其中该张量的值被视为其在计算反向传播时的值。
然而,如果你为张量注册了钩子,然后对该张量进行了原地修改,那么在原地修改之前注册的钩子同样会接收输出相对于该张量的梯度,但张量的值会被视为其在原地修改之前的值。
如果你偏好前一种情况的行为,你应该在对张量进行所有原地修改后再为其注册钩子。例如:
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解以下背景知识可能会有所帮助:当钩子注册到张量时,它们实际上永久绑定到了该张量的 grad_fn 上。因此,如果随后对该张量进行原地修改,即使该张量现在有了新的 grad_fn,在它被原地修改之前注册的钩子仍将与旧的 grad_fn 相关联。例如,当自动求导引擎在图中到达该张量的旧 grad_fn 时,它们就会触发。