评价此页

Autograd 机制#

创建于: 2017年1月16日 | 最后更新于: 2025年6月16日

本文档将概述 autograd 的工作原理和操作记录方式。深入理解这些细节并非强制要求,但我们推荐您熟悉它们,因为这将帮助您编写更高效、更简洁的程序,并有助于调试。

Autograd 如何编码历史#

Autograd 是一个反向自动微分系统。从概念上讲,autograd 在执行操作时,会记录所有创建数据的操作,形成一个有向无环图(DAG),其叶节点是输入张量,根节点是输出张量。通过追踪这个图从根到叶的路径,您可以使用链式法则自动计算梯度。

在内部,autograd 将这个图表示为一系列 Function 对象(实际上是表达式)组成的图,这些对象可以被 apply() 来计算图的求值结果。在进行前向传播时,autograd 会同时执行请求的计算,并构建一个表示计算梯度的函数图(每个 torch.Tensor.grad_fn 属性是进入此图的入口)。当前向传播完成后,我们在后向传播中评估这个图来计算梯度。

需要注意的是,这个图在每次迭代时都会从头开始重建,这正是它允许使用任意 Python 控制流语句(这些语句可以在每次迭代中改变图的整体形状和大小)的原因。您不必在启动训练前就编码所有可能的路径——您运行什么,就对什么进行微分。

保存的张量#

某些操作在执行后向传播时,需要在前向传播过程中保存中间结果。例如,函数 xx2x\mapsto x^2 保存输入 xx 以计算梯度。

在定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传播时保存张量,并在后向传播时使用 saved_tensors 检索它们。更多信息请参阅 扩展 PyTorch

对于 PyTorch 定义的操作(例如 torch.pow()),张量会根据需要自动保存。您可以(出于教育或调试目的)通过查找以 _saved 为前缀的属性来探索特定 grad_fn 保存了哪些张量。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在之前的代码中,y.grad_fn._saved_self 指向与 x 相同的 Tensor 对象。但这并非总是如此。例如:

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在底层,为了防止引用循环,PyTorch 在保存张量时对其进行了*打包*,并在读取时将其*解包*到一个不同的张量中。在这里,您通过访问 y.grad_fn._saved_result 获得的张量对象与 y 是不同的张量对象(但它们仍然共享相同的存储)。

张量是否会被打包成不同的张量对象,取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生更改,用户不应依赖它。

您可以通过 保存张量的钩子 来控制 PyTorch 的打包/解包行为。

不可微函数的梯度#

自动微分的梯度计算仅在所使用的每个基本函数可微时才有效。不幸的是,实践中使用的许多函数不具备此属性(例如 0 处的 relusqrt)。为了尽量减少不可微函数的影响,我们通过以下规则顺序定义基本操作的梯度:

  1. 如果函数在该点可微,则使用该点的梯度。

  2. 如果函数是凸函数(至少在局部),则使用最小范数的次梯度。

  3. 如果函数是凹函数(至少在局部),则使用最小范数的超梯度(考虑 -f(x) 并应用上一条)。

  4. 如果函数在该点有定义,则通过连续性定义该点的梯度(注意这里可能出现 inf,例如对于 sqrt(0))。如果存在多个可能的值,则任意选择一个。

  5. 如果函数未定义(例如 sqrt(-1)log(-1) 或输入为 NaN 时的大多数函数),则使用的梯度值是任意的(我们也可能抛出错误,但并非保证)。大多数函数将使用 NaN 作为梯度,但出于性能原因,某些函数将使用其他值(例如 log(-1))。

  6. 如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微。这将在后向传播中导致错误(如果在 no_grad 环境外用于需要 grad 的张量)。

Autograd 中的除零错误#

在 PyTorch 中进行除零运算(例如 x / 0)时,前向传播将遵循 IEEE-754 浮点算术生成 inf 值。虽然这些 inf 值可以在计算最终损失之前被屏蔽掉(例如通过索引或掩码),但 autograd 系统仍然会追踪并对整个计算图进行微分,包括除零运算。

在反向传播过程中,这可能导致梯度表达式出现问题。例如:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])

y = x / div          # Results in [inf, 1]
mask = div != 0      # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad)        # [nan, 1], not [0, 1]

在此示例中,即使我们仅使用了屏蔽后的输出(其中排除了除零运算),autograd 仍然通过完整的计算图(包括除零运算)来计算梯度。这会导致屏蔽元素的梯度为 nan,从而可能导致训练不稳定。

为了避免此问题,有几种推荐的方法:

  1. 在除法之前屏蔽

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])

mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward()      # Produces safe gradients [0, 1]
  1. 使用 MaskedTensor(实验性 API)

from torch.masked import as_masked_tensor

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])

y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward()      # Cleanly handles "undefined" vs "zero" gradients

关键原则是阻止除零运算被记录在计算图中,而不是事后屏蔽其结果。这确保了 autograd 只会计算有效运算的梯度。

当处理可能产生 infnan 值的操作时,这一点很重要,因为屏蔽输出并不能阻止产生有问题的梯度。

局部禁用梯度计算#

Python 提供了几种机制来局部禁用梯度计算:

要禁用代码块的整体梯度,可以使用 no-grad 模式和 inference 模式等上下文管理器。要更精细地排除子图的梯度计算,可以设置张量的 requires_grad 字段。

下面,除了讨论上述机制外,我们还将描述 evaluation 模式(nn.Module.eval()),该方法不用于禁用梯度计算,但由于其名称,经常与这三种模式混淆。

设置 requires_grad#

requires_grad 是一个标志,默认为 false,*除非包装在* nn.Parameter *中*,它允许对梯度计算进行子图的精细排除。它在前向和后向传播中都生效。

在前向传播过程中,只有当至少一个输入张量需要 grad 时,操作才会被记录在后向图中。在后向传播(.backward())过程中,只有 requires_grad=True 的叶张量才会在其 .grad 字段中累积梯度。

需要注意的是,尽管每个张量都有此标志,但*设置*它仅对叶张量(即没有 grad_fn 的张量,例如 nn.Module 的参数)有意义。非叶张量(即有 grad_fn 的张量)是与它们相关的后向图的张量。因此,它们的梯度将作为中间结果,以计算需要 grad 的叶张量的梯度。从这个定义可以看出,所有非叶张量都将自动具有 requires_grad=True

设置 requires_grad 应该是您控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的某些部分。

要冻结模型的一部分,只需对您不希望更新的参数应用 .requires_grad_(False)。如上所述,由于使用这些参数作为输入的计算在前向传播中不会被记录,因此它们在后向传播中不会更新其 .grad 字段,因为它们根本不会成为后向图的一部分,这正是所期望的。

由于这是一个非常常见的模式,requires_grad 也可以在模块级别使用 nn.Module.requires_grad_() 进行设置。当应用于模块时,.requires_grad_() 会影响该模块的所有参数(这些参数默认具有 requires_grad=True)。

Grad 模式#

除了设置 requires_grad 之外,还可以从 Python 中选择三种 grad 模式,它们可以影响 autograd 在 PyTorch 中如何处理计算:默认模式(grad 模式)、no-grad 模式和 inference 模式,这些都可以通过上下文管理器和装饰器进行切换。

模式

排除操作不被记录在后向图中

跳过额外的 autograd 跟踪开销

在模式启用时创建的张量稍后可以在 grad 模式下使用

示例

默认

前向传播

no-grad

优化器更新

inference

数据处理、模型评估

默认模式(Grad 模式)#

“默认模式”是指当我们没有启用 no-gradinference 模式时所处的模式。与“no-grad 模式”相对,“默认模式”有时也称为“grad 模式”。

关于默认模式最重要的一点是,它是唯一一种 requires_grad 生效的模式。在另外两种模式下,requires_grad 总是被重写为 False

no-grad 模式#

no-grad 模式下的计算,其行为就好像没有输入需要 grad 一样。换句话说,即使输入具有 require_grad=True,在 no-grad 模式下的计算也永远不会被记录在后向图中。

当您需要执行不应被 autograd 记录的操作,但仍希望稍后在 grad 模式下使用这些计算的输出时,可以使用 no-grad 模式。此上下文管理器方便您为代码块或函数禁用梯度,而无需临时将张量设置为 requires_grad=False,然后再改回 True

例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望就地更新参数,而不希望此次更新被 autograd 记录。您还打算在下一个前向传播中使用更新后的参数进行计算。

torch.nn.init 中的实现也依赖于 no-grad 模式来初始化参数,以避免在就地更新初始化参数时被 autograd 跟踪。

推理模式#

推理模式是 no-grad 模式的极端版本。与 no-grad 模式一样,推理模式下的计算不会被记录在后向图中,但启用推理模式可以使 PyTorch 更快地加速您的模型。这种更好的运行时带来了缺点:在推理模式下创建的张量在退出推理模式后将无法用于 autograd 记录的计算。

当您执行与 autograd 没有交互的计算,并且*不*打算稍后在 autograd 记录的任何计算中使用在推理模式下创建的张量时,启用推理模式。

建议您尝试在不需要 autograd 跟踪的代码部分(例如数据处理和模型评估)中使用推理模式。如果它开箱即用,那将是免费的性能提升。如果您在启用推理模式后遇到错误,请检查您是否在退出推理模式后,在 autograd 记录的计算中使用了在推理模式下创建的张量。如果您的场景无法避免此类使用,您可以随时切换回 no-grad 模式。

有关推理模式的详细信息,请参阅 推理模式

有关推理模式实现细节,请参阅 RFC-0011-InferenceMode

评估模式(nn.Module.eval()#

评估模式不是一种局部禁用梯度计算的机制。它之所以在此处包含,是因为它有时会被误认为是这样的机制。

从功能上看,module.eval()(或等效的 module.train(False))与 no-grad 模式和 inference 模式完全正交。 model.eval() 如何影响您的模型完全取决于模型中使用的特定模块,以及它们是否定义了任何特定于训练模式的行为。

您负责在模型依赖于 torch.nn.Dropouttorch.nn.BatchNorm2d 等模块时调用 model.eval()model.train(),这些模块在训练模式下可能表现不同,例如,为了避免在验证数据上更新 BatchNorm 的运行统计信息。

建议您在训练时始终使用 model.train(),在评估模型(验证/测试)时使用 model.eval(),即使您不确定模型是否具有特定于训练模式的行为,因为您使用的模块可能会更新为在训练和评估模式下具有不同的行为。

Autograd 的就地操作#

在 autograd 中支持就地操作是一个棘手的问题,我们不鼓励在大多数情况下使用它们。Autograd 的激进缓冲区释放和重用使其非常高效,很少有情况可以使就地操作显著降低内存使用量。除非您面临巨大的内存压力,否则您可能永远不需要使用它们。

限制就地操作适用性的主要原因有两个:

  1. 就地操作可能会覆盖计算梯度所需的值。

  2. 每个就地操作都需要实现来重写计算图。原地操作只是分配新对象并保留对旧图的引用,而就地操作需要将所有输入的创建者更改为代表此操作的 Function。这可能很棘手,特别是当有许多张量引用同一存储(例如通过索引或转置创建)时。如果修改输入的存储被任何其他 Tensor 引用,就地函数将引发错误。

就地正确性检查#

每个张量都维护一个版本计数器,该计数器在每次在任何操作中被标记为脏时递增。当 Function 为后向传播保存任何张量时,还会保存其包含张量的版本计数器。一旦您访问 self.saved_tensors,就会进行检查,如果大于保存的值,则会引发错误。这确保了如果您使用就地函数且未看到任何错误,您可以确信计算出的梯度是正确的。

多线程 Autograd#

autograd 引擎负责运行计算后向传播所需的所有后向操作。本节将详细介绍所有有助于您在多线程环境中充分利用它的细节。(这仅适用于 PyTorch 1.6+,因为早期版本的行为有所不同。)

用户可以使用多线程代码(例如 Hogwild 训练)来训练他们的模型,并且不会阻塞并发的后向计算。示例如下:

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

请注意,用户应该意识到一些行为:

CPU 并发#

当您通过 Python 或 C++ API 在 CPU 上的多个线程中运行 backward()grad() 时,您期望看到额外的并发,而不是在执行期间将所有后向调用串行化(PyTorch 1.6 之前的行为)。

非确定性#

如果您从多个线程并发调用 backward() 并且有共享输入(例如 Hogwild CPU 训练),那么应该预期非确定性。这可能会发生,因为参数会自动在线程之间共享,因此多个线程可能在梯度累积期间访问并尝试累积相同的 .grad 属性。这在技术上是不安全的,它可能导致竞态条件,结果可能无法使用。

开发具有共享参数的多线程模型的用户应牢记线程模型,并理解上述问题。

可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是 backward(),以避免非确定性。

图保留#

如果 autograd 图的一部分在线程之间共享(即,单线程运行前向图的第一部分,然后在多线程中运行第二部分),则图的第一部分被共享。在这种情况下,不同线程在同一图上执行 grad()backward() 可能会在其中一个线程的执行过程中销毁图,导致另一个线程崩溃。Autograd 会像两次调用 backward() 而没有 retain_graph=True 一样向用户抛出错误,并告知用户他们应该使用 retain_graph=True

Autograd 节点上的线程安全#

由于 Autograd 允许调用线程驱动其后向执行以实现潜在的并行化,因此确保 CPU 上的线程安全非常重要,特别是在共享部分/全部 GraphTask 的并行 backward() 调用时。

自定义 Python autograd.Function 由于 GIL 的存在,会自动成为线程安全的。对于内置的 C++ Autograd 节点(例如 AccumulateGrad, CopySlices)和自定义 autograd::Function,Autograd 引擎使用线程互斥锁来确保可能具有状态读写操作的 autograd 节点上的线程安全。

C++ 钩子上不存在线程安全#

Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中正确应用,您需要编写适当的线程锁定代码来确保钩子是线程安全的。

复数 Autograd#

简而言之:

  • 当您使用 PyTorch 对任何具有复数域和/或值域的函数 f(z)f(z) 进行微分时,梯度是根据函数是更大实值损失函数 g(input)=Lg(input)=L 的一部分来计算的。计算出的梯度是 Lz\frac{\partial L}{\partial z^*}(注意 z 的共轭),其负值正是梯度下降算法使用的最陡下降方向。因此,使现有优化器能够直接与复数参数协同工作的路径是可行的。

  • 此约定与 TensorFlow 的复数微分约定一致,但与 JAX 不同(JAX 计算 Lz\frac{\partial L}{\partial z})。

  • 如果您有一个内部使用复数运算的实值实值函数,那么此约定无关紧要:您将始终获得如果仅使用实数运算实现该函数所得到的结果。

如果您对数学细节感兴趣,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。

什么是复数导数?#

复数可微的数学定义是对导数的极限定义进行泛化,使其能够处理复数。考虑一个函数 f:CCf: ℂ → ℂ,

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中 uuvv 是两个变量实值函数,jj 是虚数单位。

使用导数定义,我们可以写出:

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

为了使这个极限存在,不仅 uuvv 必须是实可微的,而且 ff 还必须满足柯西-黎曼方程。换句话说:真实和虚部步骤的极限(hh)必须相等。这是一个更严格的条件。

复数可微函数通常被称为全纯函数。它们非常规整,具有您从实值可微函数中学到的所有良好特性,但在优化世界中几乎没有用处。对于优化问题,研究界通常只使用实值目标函数,因为复数不属于任何有序域,因此具有复数值损失的意义不大。

事实证明,没有任何有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数的理论不能用于优化,因此大多数人使用维尔廷格演算。

维尔廷格演算出现于...#

因此,我们拥有这套出色的复数可微性和全纯函数理论,而我们却无法利用它,因为许多常用函数并非全纯。可怜的数学家该怎么办?维尔廷格观察到,即使 f(z)f(z) 不是全纯的,也可以将其重写为双变量函数 f(z,z)f(z, z*),该函数总是全纯的。这是因为 zz 分量的实部和虚部可以表示为 zzzz^* 的形式:

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

维尔廷格演算建议研究 f(z,z)f(z, z^*),如果 ff 是实可微的,则保证是全纯的(另一种思考方式是将其视为坐标系变换,从 f(x,y)f(x, y)f(z,z)f(z, z^*)

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

从上述方程中,我们得到:

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

这是您会在Wikipedia上找到的经典的Wirtinger微积分定义。

这种改变带来了许多优美的推论。

  • 例如,柯西-黎曼方程可以被简化为仅仅说明fz=0\frac{\partial f}{\partial z^*} = 0 (也就是说,函数ff可以完全用zz来表示,而不需要引用zz^*).

  • 另一个重要的(有时也是反直觉的)结果是,正如我们稍后将看到的,当我们对实值损失函数进行优化时,在进行变量更新时应采取的步骤由Lossz\frac{\partial Loss}{\partial z^*}(而不是Lossz\frac{\partial Loss}{\partial z})。

欲了解更多信息,请参阅:https://arxiv.org/pdf/0906.4835.pdf

Wirtinger微积分在优化中有何用处?#

音频和其他领域的研究人员更常使用梯度下降来优化具有复杂变量的实值损失函数。通常,这些人将实部和虚部视为可以更新的独立通道。对于步长α/2\alpha/2和损失LL,我们可以写出以下在R2ℝ^2中的方程:

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

这些方程如何转换到复数空间C

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

发生了一件非常有意思的事情:Wirtinger微积分告诉我们,可以将上面的复数变量更新公式简化为仅引用共轭Wirtinger导数Lz\frac{\partial L}{\partial z^*},这给了我们优化中采取的准确步骤。

由于共轭Wirtinger导数给出了实值损失函数的准确优化步骤,PyTorch在对具有实值损失的函数进行微分时,会返回该导数。

PyTorch如何计算共轭Wirtinger导数?#

通常,我们的导数公式以grad_output作为输入,它表示已经计算过的传入的Vector-Jacobian乘积,即Ls\frac{\partial L}{\partial s^*},其中LL是整个计算(产生实值损失)的损失,而ss是我们函数的输出。目标是计算Lz\frac{\partial L}{\partial z^*},其中zz是函数的输入。实际上,在实值损失的情况下,我们只需计算Ls\frac{\partial L}{\partial s^*},尽管链式法则暗示我们还需要访问Ls\frac{\partial L}{\partial s}。如果您想跳过此推导,请查看本节的最后一个方程,然后跳到下一节。

让我们继续使用f:CCf: ℂ → ℂ进行讨论,定义为f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。如上所述,autograd 的梯度约定侧重于实值损失函数的优化,因此我们假设ff是更大的实值损失函数gg的一部分。使用链式法则,我们可以写出:

(1)#Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

现在使用Wirtinger导数的定义,我们可以写出:

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}

这里应该指出的是,由于uuvv是实函数,并且根据我们假设ff是实值函数的一部分,LL是实数,因此我们有:

(2)#(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

即,Ls\frac{\partial L}{\partial s} 等于grad_outputgrad\_output^*

通过求解上述关于Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v},我们得到:

(3)#Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3)代入(1),我们得到:

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用 公式 (2),我们得到

(4)#Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

最后一个方程对于编写你自己的梯度很重要,因为它将我们的导数公式分解为一个易于手工计算的更简单的公式。

我该如何写一个复数函数的导数公式?#

上述带框的方程给出了所有复数函数导数的通用公式。然而,我们仍然需要计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。有两种方法可以做到这一点:

  • 第一种方法是直接使用 Wirtinger 导数的定义来计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*} (使用 sx\frac{\partial s}{\partial x}sy\frac{\partial s}{\partial y} (可以按常规方式计算)。

  • 第二种方法是使用变量替换技巧,将 f(z)f(z) 重写为一个二元函数 f(z,z)f(z, z^*),并通过将 zzzz^* 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯的,则只会用到 zz (而 sz\frac{\partial s}{\partial z^*} 将为零)。

让我们以 f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj) 作为示例,其中 cRc \in ℝ

使用第一种方法计算 Wirtinger 导数,我们得到:

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用 公式 (4),以及 grad_output = 1.0 (这是 PyTorch 中调用 backward() 时标量输出的默认梯度输出值),我们得到

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二种方法计算 Wirtinger 导数,我们直接得到

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

And using (4) again, we get Lz=c\frac{\partial L}{\partial z^*} = c. As you can see, the second way involves lesser calculations, and comes in more handy for faster calculations.

What about cross-domain functions?#

Some functions map from complex inputs to real outputs, or vice versa. These functions form a special case of (4), which we can derive using the chain rule

  • For f:CRf: ℂ → ℝ, we get

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • f:RCf: ℝ → ℂ, we get

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

Hooks for saved tensors#

You can control how saved tensors are packed / unpacked by defining a pair of pack_hook / unpack_hook hooks. The pack_hook function should take a tensor as its single argument but can return any python object (e.g. another tensor, a tuple, or even a string containing a filename). The unpack_hook function takes as its single argument the output of pack_hook and should return a tensor to be used in the backward pass. The tensor returned by unpack_hook only needs to have the same content as the tensor passed as input to pack_hook. In particular, any autograd-related metadata can be ignored as they will be overwritten during unpacking.

An example of such pair is

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

Notice that the unpack_hook should not delete the temporary file because it might be called multiple times: the temporary file should be alive for as long as the returned SelfDeletingTempFile object is alive. In the above example, we prevent leaking the temporary file by closing it when it is no longer needed (on deletion of the SelfDeletingTempFile object).

注意

We guarantee that pack_hook will only be called once but unpack_hook can be called as many times as the backward pass requires it and we expect it to return the same data each time.

警告

Performing inplace operations on the input of any of the functions is forbidden as they may lead to unexpected side-effects. PyTorch will throw an error if the input to a pack hook is modified inplace but does not catch the case where the input to an unpack hook is modified inplace.

Registering hooks for a saved tensor#

You can register a pair of hooks on a saved tensor by calling the register_hooks() method on a SavedTensor object. Those objects are exposed as attributes of a grad_fn and start with the _raw_saved_ prefix.

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

The pack_hook method is called as soon as the pair is registered. The unpack_hook method is called each time the saved tensor needs to be accessed, either by means of y.grad_fn._saved_self or during the backward pass.

警告

If you maintain a reference to a SavedTensor after the saved tensors have been released (i.e. after backward has been called), calling its register_hooks() is forbidden. PyTorch will throw an error most of the time but it may fail to do so in some cases and undefined behavior may arise.

Registering default hooks for saved tensors#

Alternatively, you can use the context-manager saved_tensors_hooks to register a pair of hooks which will be applied to all saved tensors that are created in that context.

示例

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x.detach()
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

The hooks defined with this context manager are thread-local. Hence, the following code will not produce the desired effects because the hooks do not go through DataParallel.

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

Note that using those hooks disables all the optimization in place to reduce Tensor object creation. For example

with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

Without the hooks, x, y.grad_fn._saved_self and y.grad_fn._saved_other all refer to the same tensor object. With the hooks, PyTorch will pack and unpack x into two new tensor objects that share the same storage with the original x (no copy performed).

Backward Hooks execution#

This section will discuss when different hooks fire or don’t fire. Then it will discuss the order in which they are fired. The hooks that will be covered are: backward hooks registered to Tensor via torch.Tensor.register_hook(), post-accumulate-grad hooks registered to Tensor via torch.Tensor.register_post_accumulate_grad_hook(), post-hooks registered to Node via torch.autograd.graph.Node.register_hook(), and pre-hooks registered to Node via torch.autograd.graph.Node.register_prehook().

Whether a particular hook will be fired#

Hooks registered to a Tensor via torch.Tensor.register_hook() are executed when gradients are being computed for that Tensor. (Note that this does not require the Tensor’s grad_fn to be executed. For example, if the Tensor is passed as part of the inputs argument to torch.autograd.grad(), the Tensor’s grad_fn may not be executed, but the hook register to that Tensor will always be executed.)

Hooks registered to a Tensor via torch.Tensor.register_post_accumulate_grad_hook() are executed after the gradients have been accumulated for that Tensor, meaning the Tensor’s grad field has been set. Whereas hooks registered via torch.Tensor.register_hook() are run as gradients are being computed, hooks registered via torch.Tensor.register_post_accumulate_grad_hook() are only triggered once the Tensor’s grad field is updated by autograd at the end of the backward pass. Thus, post-accumulate-grad hooks can only be registered for leaf Tensors. Registering a hook via torch.Tensor.register_post_accumulate_grad_hook() on a non-leaf Tensor will error, even if you call backward(retain_graph=True).

Hooks registered to torch.autograd.graph.Node using torch.autograd.graph.Node.register_hook() or torch.autograd.graph.Node.register_prehook() are only fired if the Node it was registered to is executed.

Whether a particular Node is executed may depend on whether the backward pass was called with torch.autograd.grad() or torch.autograd.backward(). Specifically, you should be aware of these differences when you register a hook on a Node corresponding to a Tensor that you are passing to torch.autograd.grad() or torch.autograd.backward() as part of the inputs argument.

If you are using torch.autograd.backward(), all of the above mentioned hooks will be executed, whether or not you specified the inputs argument. This is because .backward() executes all Nodes, even if they correspond to a Tensor specified as an input. (Note that the execution of this additional Node corresponding to Tensors passed as inputs is usually unnecessary, but done anyway. This behavior is subject to change; you should not depend on it.)

On the other hand, if you are using torch.autograd.grad(), the backward hooks registered to Nodes that correspond to the Tensors passed to input may not be executed, because those Nodes will not be executed unless there is another input that depends on the gradient result of this Node.

The order in which the different hooks are fired#

The order in which things happen are

  1. hooks registered to Tensor are executed

  2. pre-hooks registered to Node are executed (if Node is executed).

  3. the .grad field is updated for Tensors that retain_grad

  4. Node is executed (subject to rules above)

  5. for leaf Tensors that have .grad accumulated, post-accumulate-grad hooks are executed

  6. post-hooks registered to Node are executed (if Node is executed)

If multiple hooks of the same type are registered on the same Tensor or Node they are executed in the order in which they are registered. Hooks that are executed later can observe the modifications to the gradient made by earlier hooks.

Special hooks#

torch.autograd.graph.register_multi_grad_hook() is implemented using hooks registered to Tensors. Each individual Tensor hook is fired following the Tensor hook ordering defined above and the registered multi-grad hook is called when the last Tensor gradient is computed.

torch.nn.modules.module.register_module_full_backward_hook() is implemented using hooks registered to Node. As the forward is computed, hooks are registered to grad_fn corresponding to the inputs and outputs of the module. Because a module may take multiple inputs and return multiple outputs, a dummy custom autograd Function is first applied to the inputs of the module before forward and the outputs of the module before the output of forward is returned to ensure that those Tensors share a single grad_fn, which we can then attach our hooks to.

Behavior of Tensor hooks when Tensor is modified in-place#

Usually hooks registered to a Tensor receive the gradient of the outputs with respect to that Tensor, where the value of the Tensor is taken to be its value at the time backward is computed.

However, if you register hooks to a Tensor, and then modify that Tensor in-place, hooks registered before in-place modification similarly receive gradients of the outputs with respect to the Tensor, but the value of the Tensor is taken to be its value before in-place modification.

If you prefer the behavior in the former case, you should register them to the Tensor after all in-place modifications to it have been made. For example

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

Furthermore, it can be helpful to know that under the hood, when hooks are registered to a Tensor, they actually become permanently bound to the grad_fn of that Tensor, so if that Tensor is then modified in-place, even though the Tensor now has a new grad_fn, hooks registered before it was modified in-place will continue to be associated with the old grad_fn, e.g. they will fire when that Tensor’s old grad_fn is reached in the graph by the autograd engine.