评价此页

torch.autograd.gradcheck.gradcheck#

torch.autograd.gradcheck.gradcheck(func, inputs, *, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, nondet_tol=0.0, check_undefined_grad=True, check_grad_dtypes=False, check_batched_grad=False, check_batched_forward_grad=False, check_forward_ad=False, check_backward_ad=True, fast_mode=False, masked=None)[源代码]#

检查通过有限差分计算的梯度与 inputs 中是浮点或复数类型且 requires_grad=True 的张量的解析梯度是否一致。

数值梯度和解析梯度之间的检查使用 allclose()

对于我们为优化目的考虑的大多数复杂函数,不存在雅可比矩阵的概念。取而代之的是,gradcheck 验证 Wirtinger 和共轭 Wirtinger 导数的数值和解析值是否一致。因为梯度计算是在假设整个函数具有实值输出的条件下进行的,所以我们以特殊方式处理具有复数输出的函数。对于这些函数,gradcheck 应用于两个实值函数:第一个函数对应于取复数输出的实部,第二个函数对应于取复数输出的虚部。有关更多详细信息,请参阅 复数自动微分

注意

默认值是为双精度 input 设计的。如果 input 的精度较低(例如,FloatTensor),此检查很可能会失败。

注意

在非可微点上求值时,gradcheck 可能会失败,因为通过有限差分计算的数值梯度可能与解析计算的梯度不同(不一定是因为其中一个不正确)。有关更多背景信息,请参阅 非可微函数的梯度

警告

如果 input 中任何被检查的张量具有重叠内存,即不同的索引指向相同的内存地址(例如,来自 torch.Tensor.expand()),此检查很可能会失败,因为在这些索引上通过点扰动计算的数值梯度会改变共享相同内存地址的所有其他索引的值。

参数
  • func (function) – 一个接受 Tensor 输入并返回 Tensor 或 Tensor 元组的 Python 函数

  • inputs (tuple of Tensor or Tensor) – 函数的输入

  • eps (float, optional) – 有限差分的扰动

  • atol (float, optional) – 绝对容差

  • rtol (float, optional) – 相对容差

  • raise_exception (bool, optional) – 指示在检查失败时是否引发异常。异常提供了有关失败确切性质的更多信息。这在调试 gradchecks 时非常有用。

  • nondet_tol (float, optional) – 非确定性容差。当使用相同的输入运行微分时,结果必须完全匹配(默认值 0.0)或在此容差范围内。

  • check_undefined_grad (bool, optional) – 如果为 True,则检查是否支持未定义的输出梯度,并将其视为零,用于 Tensor 输出。

  • check_batched_grad (bool, optional) – 如果为 True,则检查是否可以使用原型的 vmap 支持来计算批处理梯度。默认为 False。

  • check_batched_forward_grad (bool, optional) – 如果为 True,则检查是否可以使用前向 AD 和原型的 vmap 支持来计算批处理前向梯度。默认为 False

  • check_forward_ad (bool, optional) – 如果为 True,则检查使用前向模式 AD 计算的梯度是否与数值梯度匹配。默认为 False

  • check_backward_ad (bool, optional) – 如果为 False,则不执行任何依赖于后向模式 AD 实现的检查。默认为 True

  • fast_mode (bool, optional) – gradcheck 和 gradgradcheck 的快速模式目前仅针对 R 到 R 函数实现。如果输入和输出都不是复数,则运行一个更快的 gradcheck 实现,该实现不再计算整个雅可比矩阵;否则,将回退到慢速实现。

  • masked (bool, optional) – 如果为 True,则忽略稀疏张量中未指定元素的梯度。默认为 False

返回

如果所有差异都满足 allclose 条件,则为 True

返回类型

布尔值