评价此页

Gradcheck 机制#

创建日期: 2021年4月27日 | 最后更新日期: 2025年6月18日

本文档概述了 gradcheck()gradgradcheck() 函数的工作原理。

本文将涵盖实值和复值函数的正向和反向模式自动微分,以及高阶导数。本文档还将涵盖 gradcheck 的默认行为以及传递 fast_mode=True 参数(下文称为快速 gradcheck)的情况。

符号和背景信息#

在本文档中,我们将使用以下约定

  1. xx, yy, aa, bb, vv, uu, ururuiui 是实值向量,而 zz 是一个复值向量,可以表示为两个实值向量的形式:z=a+ibz = a + i b

  2. NNMM 分别是我们用于输入和输出空间的两个整数。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我们的基本实到实函数,其中 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我们的基本复到实函数,其中 y=g(z)y = g(z)

对于简单的实到实函数,我们将其雅可比矩阵表示为 JfJ_f,大小为 M×NM \times N。此矩阵包含所有偏导数,其中位置 (i,j)(i, j) 的项是 yixj\frac{\partial y_i}{\partial x_j}. 反向模式自动微分则计算,对于给定的向量 vv,大小为 MM 的量 vTJfv^T J_f. 另一方面,正向模式自动微分计算,对于给定的向量 uu,大小为 NN 的量 JfuJ_f u

对于包含复数值的函数,情况要复杂得多。此处仅提供概要,完整描述请参见 复数自动微分

为了满足复数可微性(柯西-黎曼方程)的约束条件,对所有实值损失函数来说,这些约束条件都过于严格,因此我们采用了维尔丁格演算。在维尔丁格演算的基本设置中,链式法则需要同时访问维尔丁格导数(下文称为 WW)和共轭维尔丁格导数(下文称为 CWCW)。WWCWCW 都需要传播,因为通常情况下,尽管有名称,一个并不是另一个的复共轭。

为了避免传播这两个值,对于反向模式自动微分,我们始终假定正在计算导数的函数要么是实值函数,要么是更大的实值函数的一部分。此假设意味着我们在反向传播过程中计算的所有中间梯度也与实值函数相关联。实际上,此假设在进行优化时并无限制,因为此类问题需要实值目标(因为复数没有自然顺序)。

在此假设下,使用 WWCWCW 的定义,我们可以证明 W=CWW = CW^* (我们在此使用 * 表示复共轭),因此只需要“反向传播通过图”其中一个值,而另一个可以轻松恢复。为了简化内部计算,PyTorch 使用 2CW2 * CW 作为其反向传播并返回值,当用户请求梯度时。与实值情况类似,当输出实际在 RM\mathcal{R}^M 时,反向模式自动微分不会计算 2CW2 * CW,而仅计算 vT(2CW)v^T (2 * CW),其中 vRMv \in \mathcal{R}^M 是给定的向量。

对于正向模式自动微分,我们采用类似的逻辑,在这种情况下,我们假设该函数是更大函数的一部分,该函数的输入在 R\mathcal{R} 中。在此假设下,我们可以得出类似的结论,即每个中间结果都对应一个输入在 R\mathcal{R} 中的函数,并且在这种情况下,使用 WWCWCW 的定义,我们可以证明对于中间函数 W=CWW = CW。为了确保正向和反向模式在单变量函数的基本情况下计算相同的量,正向模式也计算 2CW2 * CW。与实值情况类似,当输入实际在 RN\mathcal{R}^N 中时,正向模式自动微分不计算 2CW2 * CW,而是仅计算 (2CW)u(2 * CW) u,其中 uRNu \in \mathcal{R}^N 是给定的向量。

默认反向模式 gradcheck 行为#

实到实函数#

为了测试函数 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我们通过两种方式重构完整的雅可比矩阵 JfJ_f,大小为 M×NM \times N:一种是解析方法,另一种是数值方法。解析方法使用我们的反向模式自动微分,而数值方法使用有限差分。然后逐个元素地比较两个重构的雅可比矩阵是否相等。

默认实输入数值评估#

如果我们考虑一维函数(N=M=1N = M = 1)的基本情况,那么我们可以使用 维基百科文章 中的基本有限差分公式。我们使用“中心差分”以获得更好的数值特性。

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

该公式易于推广到多输出(M>1M \gt 1),其中 yx\frac{\partial y}{\partial x} 是大小为 M×1M \times 1 的列向量,例如 f(x+eps)f(x + eps)。在这种情况下,上述公式可以原样重用,并且只需对用户函数进行两次评估(即 f(x+eps)f(x + eps)f(xeps)f(x - eps))就可以近似整个雅可比矩阵。

处理多输入(N>1N \gt 1)的情况计算成本更高。在这种情况下,我们逐个循环遍历所有输入,并对 xx 的每个元素依次应用 epseps 扰动。这允许我们逐列重构 JfJ_f 矩阵。

默认实输入解析评估#

对于解析评估,我们利用上面所述的事实,即反向模式自动微分计算 vTJfv^T J_f. 对于只有一个输出的函数,我们直接使用 v=1v = 1 来通过一次反向传播恢复整个雅可比矩阵。

对于有多个输出的函数,我们采用一个 for 循环,该循环遍历每个输出,其中每个 vv 是一个对应于每个输出的独热向量。这允许我们逐行重构 JfJ_f 矩阵。

复到实函数#

为了测试函数 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,我们重构包含 2CW2 * CW 的(复值)矩阵。

默认复数输入数值评估#

首先考虑 N=M=1N = M = 1 的基本情况。我们从 这篇研究论文(第 3 章)得知:

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

请注意,在上述等式中,ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 导数。为了在数值上计算它们,我们使用上面为实到实情况描述的方法。这允许我们计算 CWCW 矩阵,然后将其乘以 22

请注意,截至撰写本文时,代码以一种略显迂回的方式计算此值。

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

默认复数输入解析评估#

由于反向模式自动微分已经精确计算了 CWCW 的两倍,因此我们在这里使用了与实到实情况相同的技巧,当有多个实输出时,我们逐行重构矩阵。

具有复数输出的函数#

在这种情况下,用户提供的函数不遵循自动微分的假设,即我们计算反向模式自动微分的函数是实值的。这意味着直接在函数上使用自动微分没有明确定义。为了解决这个问题,我们将测试函数 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M (其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C})替换为两个函数:hrhrhihi,使得

我们定义了以下函数:

其中 qPq \in \mathcal{P}。然后,我们将根据上述实到实或复到实的情况,对 hrhrhihi 进行基本的梯度检验,具体取决于 P\mathcal{P}

请注意,截至撰写本文时,代码并未显式创建这些函数,而是通过将 realrealimagimag 参数手动传递给不同的函数来实现链式法则。当 grad_out=1\text{grad\_out} = 1 时,我们考虑 hrhr。当 grad_out=1j\text{grad\_out} = 1j 时,我们考虑 hihi

快速反向模式梯度检验#

虽然上述梯度检验的表述非常有用,可以确保正确性和可调试性,但它非常慢,因为它会重建完整的雅可比矩阵。本节介绍了一种执行梯度检验的更快方法,同时不影响其正确性。通过在检测到错误时添加特殊逻辑,可以恢复可调试性。在这种情况下,我们可以运行默认版本,该版本会重建完整的矩阵,以便向用户提供完整的详细信息。

这里的总体策略是找到一个标量量,该标量量可以通过数值和解析方法高效计算,并且能够充分代表缓慢梯度检验计算的完整矩阵,从而确保它能够捕获雅可比矩阵中的任何差异。

实到实函数的快速梯度检验#

我们想在这里计算的标量量是 vTJfuv^T J_f u,对于给定的随机向量 vRMv \in \mathcal{R}^M 和随机单位范数向量 uRNu \in \mathcal{R}^N

对于数值评估,我们可以高效地计算:

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后,我们计算此向量与 vv 的点积,以获得感兴趣的标量值。

对于解析版本,我们可以使用反向模式自动微分来直接计算 vTJfv^T J_f。然后,我们将其与 uu 进行点积以获得期望值。

复到实函数的快速梯度检验#

与实到实情况类似,我们要对完整矩阵进行约简。但是,2CW2 * CW 矩阵是复数值的,因此在这种情况下,我们将比较复标量。

由于在数值情况下高效计算存在一些限制,并为了尽量减少数值评估的次数,我们计算以下(尽管可能令人惊讶的)标量值:

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速复数输入数值评估#

我们首先考虑如何通过数值方法计算 ss。为此,请记住我们考虑的是 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,并且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我们将其重写为:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在上面的公式中,我们可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像实到实情况的快速版本一样进行评估。一旦计算出这些实值量,我们就可以重建右侧的复向量,并与实值 vv 向量进行点积。

快速复数输入解析评估#

在解析情况下,事情会更简单,我们重写公式为:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用反向模式自动微分提供一种高效计算 vT(2CW)v^T (2 * CW) 的方法,然后将其实部与 urur 进行点积,虚部与 uiui 进行点积,最后重构出最终的复标量 ss

为什么不使用复数 uu#

此时,您可能会想,为什么我们不选择一个复数 uu 并直接执行约简 2vTCWu2 * v^T CW u'. 为了深入探讨这一点,在本段中,我们将使用复数版本的 uu,记为 u=ur+iuiu' = ur' + i ui'. 使用这样的复数 uu',问题在于进行数值评估时,我们需要计算

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这就需要四次实到实有限差分评估(是上述方法的两倍)。由于这种方法没有更多的自由度(变量数量相同),并且我们试图在此处实现最快的评估,因此我们使用了上述另一种表述。

具有复数输出的函数的快速梯度检验#

与慢速情况一样,我们考虑两个实值函数,并为每个函数使用上述适当的规则。

二阶梯度检验实现#

PyTorch 还提供了一个验证二阶梯度的实用程序。这里的目标是确保反向实现的微分也是正确的,并且计算结果正确。

此功能通过考虑函数 F:x,vvTJfF: x, v \to v^T J_f and use the gradcheck defined above on this function. Note that vv in this case is just a random vector with the same type as f(x)f(x).

The fast version of gradgradcheck is implemented by using the fast version of gradcheck on that same function FF.