torch.func.linearize#
- torch.func.linearize(func, *primals)[源代码]#
返回
func在primals处的值以及在primals处的线性近似值。- 参数
func (Callable) – 一个接受一个或多个参数的 Python 函数。
primals (Tensors) –
func的位置参数,它们必须都是 Tensor。这些是函数被线性逼近的值。
- 返回
返回一个
(output, jvp_fn)元组,其中包含func应用于primals后的输出,以及一个用于计算在primals处求值的func的 jvp 的函数。- 返回类型
如果要在
primals处多次计算 jvp,那么linearize会很有用。然而,为了实现这一点,linearize 会保存中间计算,并比直接应用 jvp 有更高的内存要求。因此,如果所有tangents都已知,那么计算 vmap(jvp) 而不是使用 linearize 可能会更有效。注意
linearize会计算func两次。请提交一个 issue 以便实现单次计算。示例
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>