torch.func.hessian#
- torch.func.hessian(func, argnums=0)[源代码]#
通过前向-反向策略,计算
func相对于索引为argnum的参数(们)的 Hessian。正向-逆向策略(组合
jacfwd(jacrev(func)))是获得良好性能的默认选择。也可以通过jacfwd()和jacrev()的其他组合来计算 Hessian,例如jacfwd(jacfwd(func))或jacrev(jacrev(func))。- 参数
- 返回
返回一个函数,该函数接受与
func相同的输入,并返回func相对于argnums指定的一个或多个参数的 Hessian。
注意
您可能会看到此 API 报错“forward-mode AD not implemented for operator X”。如果是这种情况,请提交 bug 报告,我们将优先处理。另一种方法是使用
jacrev(jacrev(func)),它具有更好的运算符覆盖范围。基本用法,对于 R^N -> R^1 函数,得到一个 N x N 的 Hessian
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))