Gradcheck 机制#
创建时间: Apr 27, 2021 | 最后更新时间: Jun 18, 2025
本文档概述了 gradcheck() 和 gradgradcheck() 函数的工作原理。
它将涵盖实值和复值函数的正向和反向模式自动微分 (AD),以及高阶导数。本文档还涵盖了 gradcheck 的默认行为以及传递了 fast_mode=True 参数的情况(下文称为快速 gradcheck)。
符号和背景信息#
在本文档中,我们将使用以下约定:
, , , , , , 和 是实值向量,而 是一个复值向量,可以重写为两个实值向量的形式,即 。
和 是我们分别用于输入和输出空间维度的两个整数。
是我们的基本实到实函数,使得 。
是我们的基本复到实函数,使得 。
对于简单的实到实情况,我们用 来表示与 相关的雅可比矩阵,其大小为 。此矩阵包含所有偏导数,其中位置 处的元素包含 。反向模式 AD 接着计算给定向量 (大小为 )的量 . 另一方面,正向模式 AD 计算给定向量 (大小为 )的量 。
对于包含复数值的函数,情况要复杂得多。我们只在这里提供要点,完整描述可以在 复数自动微分 中找到。
为满足复可微性(柯西-黎曼方程)而设的约束条件对于所有实值损失函数来说都过于严格,因此我们选择使用 Wirtinger 微积分。在 Wirtinger 微积分的基本设置中,链式法则需要访问 Wirtinger 导数(下文称为 )和共轭 Wirtinger 导数(下文称为 )。 和 都需要被反向传播,因为一般来说,尽管有其名称,其中一个并非另一个的复共轭。
为了避免传播这两个值,对于反向模式 AD,我们始终假设要求导数的函数是实值函数或更大的实值函数的一部分。此假设意味着我们在反向传播过程中计算的所有中间梯度也与实值函数相关联。在实践中,这种假设在进行优化时并不具有限制性,因为此类问题需要实值目标(因为复数没有自然顺序)。
在此假设下,使用 和 定义,我们可以证明 (我们在此处使用 表示复共轭),因此只需要“反向传播”其中一个值,因为另一个可以很容易地恢复。为了简化内部计算,当用户要求梯度时,PyTorch 使用 作为反向传播并返回的值。与实数情况类似,当输出实际上在 中时,反向模式 AD 不计算 ,而是仅计算 给定向量 。
对于正向模式 AD,我们使用类似的逻辑,在这种情况下,假设该函数是更大函数的一部分,其输入在 中。在此假设下,我们可以做出类似的声明,即每个中间结果都对应一个输入在 中的函数,在这种情况下,使用 和 定义,我们可以证明对于中间函数有 。为了确保在单维函数的简单情况下,正向和反向模式计算相同的量,正向模式也计算 。与实数情况类似,当输入实际上在 中时,正向模式 AD 不计算 ,而是仅计算 给定向量 。
默认反向模式 gradcheck 行为#
实到实函数#
为了测试一个函数 , 我们通过解析和数值两种方式重构出完整的雅可比矩阵 ,其大小为 。解析版本使用我们的后向模式自动微分,而数值版本则使用有限差分。然后对两个重构的雅可比矩阵进行逐元素比较以验证其相等性。
默认实值输入数值评估#
如果我们考虑一个一维函数()的基本情况,那么我们可以使用来自维基百科文章的基本有限差分公式。我们使用“中心差分”以获得更好的数值性质
这个公式很容易推广到多输出()的情况,方法是让 成为一个大小为 的列向量,例如 。在这种情况下,上述公式可以照常使用,并且仅通过两次用户函数评估(即 和 )来近似整个雅可比矩阵。
处理多输入()的情况在计算上更昂贵。在这种情况下,我们逐一循环所有输入,并对 的每个元素依次应用 扰动。这使我们能够逐列重构 矩阵。
默认实值输入解析评估#
对于解析评估,我们利用了上述描述的后向模式自动微分计算 的事实。对于单输出函数,我们仅使用 来通过一次后向传递恢复完整的雅可比矩阵。
对于多输出函数,我们采用一个 for 循环,该循环遍历各个输出,其中每个 是一个对应于每个输出的独热向量。这使得我们能够逐行重构 矩阵。
复数到实数函数#
为了测试一个函数 ,其中 ,我们重构出包含 的(复数值)矩阵。
默认复数值输入数值评估#
首先考虑 的基本情况。我们从(第 3 章)这篇研究论文中得知:
请注意,在上面的等式中, 和 是 导数。为了进行数值评估,我们采用了前面针对实到实情况描述的方法。这使得我们能够计算 矩阵,然后将其乘以 。
请注意,在撰写本文时,代码以一种稍微迂回的方式计算此值。
# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above
ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.
默认复数值输入解析评估#
由于后向模式自动微分已经计算出恰好是 的两倍的导数,我们在此仅使用与实到实情况相同的技巧,并在存在多个实输出时逐行重构矩阵。
具有复数输出的函数#
在此情况下,用户提供的函数不遵循自动微分(autograd)关于我们计算反向自动微分的函数是实值函数的假设。这意味着直接在 autograd 上使用此函数是没有明确定义的。为了解决这个问题,我们将替换函数 (其中 可以是 或 ),通过两个函数: 和 ,使得
其中 。然后我们对 和 进行基本的 gradcheck,根据 的情况,使用上面描述的实到实或复到实的情况。
请注意,在撰写本文时,代码并没有显式创建这些函数,而是通过手动传递 参数到不同的函数来执行链式法则。 时,我们考虑的是 。当 时,我们考虑的是 。
快速反向模式 gradcheck#
虽然上面的 gradcheck 公式在保证正确性和可调试性方面都很好,但由于它重建了完整的雅可比矩阵,所以速度非常慢。本节将介绍一种更快地执行 gradcheck 的方法,而不会影响其正确性。当检测到错误时,可以通过添加特殊逻辑来恢复可调试性。在这种情况下,我们可以运行默认版本,该版本会重建完整矩阵,以便向用户提供完整的详细信息。
这里的高层策略是找到一个标量量,该标量量可以通过数值和分析方法高效计算,并且能够充分代表慢速 gradcheck 计算出的完整矩阵,以确保它能够捕捉到雅可比矩阵中的任何差异。
实到实函数的快速 gradcheck#
我们在这里要计算的标量量是 ,其中 是给定的随机向量, 是单位范数随机向量。
对于数值评估,我们可以有效地计算:
然后我们计算这个向量与 的点积,以获得所需的标量值。
对于分析版本,我们可以使用反向模式 AD 来直接计算 。然后与 进行点积以获得期望值。
复到实函数的快速 gradcheck#
与实到实情况类似,我们也想进行矩阵的降维。但 矩阵是复数值的,因此在这种情况下,我们将与复标量进行比较。
由于数值情况下的计算效率存在一些限制,并且为了尽量减少数值评估的次数,我们计算以下(尽管令人惊讶的)标量值:
其中 , 和 。
快速复数输入数值评估#
我们首先考虑如何使用数值方法计算 。为此,我们考虑 ,其中 ,并且 ,我们将其重写如下:
在此公式中,我们可以看到 和 可以像实数到实数情况下的快速版本一样进行计算。一旦计算出这些实数值量,我们就可以重建右侧的复数向量,并与实数向量 进行点积。
快速复数输入解析求值#
对于解析情况,事情会更简单,我们将公式重写为:
因此,我们可以利用反向模式 AD 提供了一种计算 的高效方法,然后将实部与 进行点积,虚部与 进行点积,最后重构出最终的复标量 。
为何不使用复数 #
此时,你可能会想,为什么我们不选择一个复数 并且直接执行 . 为了深入探讨这一点,在这一段中,我们将使用 的复数版本,记为 . 使用这种复数 ,问题在于进行数值评估时,我们需要计算
这将需要四次实数到实数的有限差分评估(是上述提出的方法的两倍)。由于这种方法没有更多的自由度(实值变量的数量相同),并且我们试图在这里获得最快的可能评估,因此我们使用了上面的另一种表述。
带复数输出函数的快速 gradcheck#
与慢速情况一样,我们考虑两个实值函数,并为每个函数使用上述的相应规则。
Gradgradcheck 实现#
PyTorch 还提供了一个用于验证二阶梯度的实用程序。这里的目标是确保反向实现的正确可微性并计算出正确的值。
此功能通过考虑函数 来实现,并在此函数上使用上面定义的 gradcheck。请注意,在这种情况下, 只是一个与 具有相同类型的随机向量。
gradgradcheck 的快速版本是通过对同一函数 使用快速版本的 gradcheck 来实现的。