评价此页

torch.func.hessian#

torch.func.hessian(func, argnums=0)[source]#

计算 func 关于在索引 argnum 处的参数的 Hessian,通过前向-反向组合策略实现。

前向-反向组合策略(即 jacfwd(jacrev(func)))是实现良好性能的默认选择。也可以通过 jacfwd()jacrev() 的其他组合来计算 Hessian,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

参数
  • func (function) – 一个 Python 函数,接收一个或多个参数,其中至少一个必须是 Tensor,并返回一个或多个 Tensor。

  • argnums (intTuple[int]) – 可选,整数或整数元组,表示要计算 Hessian 的参数索引。默认为 0。

返回

返回一个函数,该函数接收与 func 相同的输入,并返回 func 相对于 argnums 指定的参数的 Hessian。

注意

您可能会看到此 API 出现“不支持算子 X 的前向模式 AD”的错误。如果出现这种情况,请提交 bug 报告,我们将优先处理。另一种方法是使用 jacrev(jacrev(func)),它具有更好的算子覆盖范围。

对于 R^N -> R^1 函数的基本用法,会得到一个 N x N 的 Hessian。

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))