torch.func.jvp#
- torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[源码]#
代表雅可比矩阵-向量积,返回一个元组,其中包含 func(*primals) 的输出以及“在
primals处计算的func的雅可比矩阵”与tangents的乘积。这也被称为前向模式自动微分。- 参数:
func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
primals (Tensors) – 传递给
func的位置参数,所有这些参数都必须是 Tensor。返回的函数还将计算相对于这些参数的导数。tangents (Tensors) – 计算雅可比矩阵-向量积的“向量”。必须与
func的输入具有相同的结构和大小。has_aux (bool) – 指示
func返回一个(output, aux)元组的标志,其中第一个元素是待微分函数的输出,第二个元素是其他不会被微分的辅助对象。默认值:False。
- 返回:
返回一个
(output, jvp_out)元组,其中包含func在primals处计算的输出以及雅可比矩阵-向量积。如果has_aux 为 True,则返回一个(output, jvp_out, aux)元组。
注意
您可能会看到此 API 报错“forward-mode AD not implemented for operator X”。如果发生这种情况,请提交 bug 报告,我们将优先处理。
当您希望计算从 R^1 到 R^N 的函数的梯度时,jvp 非常有用。
>>> from torch.func import jvp >>> x = torch.randn([]) >>> f = lambda x: x * torch.tensor([1.0, 2.0, 3]) >>> value, grad = jvp(f, (x,), (torch.tensor(1.0),)) >>> assert torch.allclose(value, f(x)) >>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))
jvp()可以通过为每个输入传递切线来支持具有多个输入的函数。>>> from torch.func import jvp >>> x = torch.randn(5) >>> y = torch.randn(5) >>> f = lambda x, y: (x * y) >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) >>> assert torch.allclose(output, x + y)