functorch.vjp¶
-
functorch.vjp(func, *primals, has_aux=False)[源代码]¶ 代表向量-雅可比积,返回一个元组,其中包含应用于
primals的func的结果,以及一个函数,当给定cotangents时,计算func关于primals的反向模式雅可比矩阵乘以cotangents。- 参数
func (Callable) – 一个接受一个或多个参数的 Python 函数。必须返回一个或多个张量。
primals (张量) –
func的位置参数,必须全部为张量。返回的函数也将计算关于这些参数的导数has_aux (布尔值) – 标志,指示
func返回一个(output, aux)元组,其中第一个元素是待微分的函数的输出,第二个元素是不会被微分的其他辅助对象。默认值:False。
- 返回值
返回一个
(output, vjp_fn)元组,其中包含应用于primals的func的输出,以及一个计算func关于所有primals的 vjp 的函数,使用传递给返回函数的 cotangents。如果has_aux is True,则返回一个(output, vjp_fn, aux)元组。返回的vjp_fn函数将返回每个 VJP 的元组。
在简单情况下使用时,
vjp()的行为与grad()相同>>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> grad = vjpfunc(torch.tensor(1.))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x))
但是,
vjp()可以通过传入每个输出的 cotangents 来支持具有多个输出的函数>>> x = torch.randn([5]) >>> f = lambda x: (x.sin(), x.cos()) >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()甚至可以支持输出为 Python 结构体>>> x = torch.randn([5]) >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
由
vjp()返回的函数将计算关于每个primals的偏导数>>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) >>> cotangents = torch.randn([5, 5]) >>> vjps = vjpfunc(cotangents) >>> assert len(vjps) == 2 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals是f的位置参数。所有关键字参数都使用其默认值>>> x = torch.randn([5]) >>> def f(x, scale=4.): >>> return x * scale >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
注意
将 PyTorch
torch.no_grad与vjp一起使用。案例 1:在函数内部使用torch.no_grad>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在这种情况下,
vjp(f)(x)将尊重内部的torch.no_grad。案例 2:在
torch.no_grad上下文管理器内部使用vjp>>> # xdoctest: +SKIP(failing) >>> with torch.no_grad(): >>> vjp(f)(x)
在这种情况下,
vjp将尊重内部的torch.no_grad,但不尊重外部的。这是因为vjp是一个“函数转换”:其结果不应依赖于f外部的上下文管理器的结果。警告
我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch.vjp 已弃用,并且将在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.vjp;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/stable/func.migrating.html