评价此页

torch.func.jvp#

torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[source]#

代表 Jacobian-vector product,返回一个元组,其中包含 func(*primals) 的输出以及在 primals 处计算的“func 的 Jacobian”与 tangents 的乘积。这也被称为前向模式自动微分。

参数
  • func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors

  • primals (Tensors) – 传递给 func 的位置参数,所有这些参数都必须是 Tensor。返回的函数还将计算相对于这些参数的导数。

  • tangents (Tensors) – 用于计算 Jacobian-vector product 的“向量”。其结构和大小必须与 func 的输入相同。

  • has_aux (bool) – 一个标志,指示 func 返回一个 (output, aux) 元组,其中第一个元素是要求导函数的输出,第二个元素是其他不会被求导的辅助对象。默认为 False。

返回

返回一个 (output, jvp_out) 元组,包含 funcprimals 处计算的输出以及 Jacobian-vector product。如果 has_aux True,则返回一个 (output, jvp_out, aux) 元组。

注意

您可能会看到此 API 报错“forward-mode AD not implemented for operator X”。如果出现这种情况,请提交一个 bug 报告,我们将优先处理。

当您希望计算函数 R^1 -> R^N 的梯度时,jvp 非常有用。

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1.0, 2.0, 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.0),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))

jvp() 可以通过为每个输入传递对应的 tangents 来支持具有多个输入的函数。

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)