评价此页

torch.func.linearize#

torch.func.linearize(func, *primals)[源代码]#

返回 funcprimals 处的值以及在 primals 处的线性近似值。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。

  • primals (Tensors) – func 的位置参数,它们必须都是 Tensor。这些是函数被线性逼近的值。

返回

返回一个 (output, jvp_fn) 元组,其中包含 func 应用于 primals 后的输出,以及一个用于计算在 primals 处求值的 func 的 jvp 的函数。

返回类型

tuple[Any, Callable]

如果要在 primals 处多次计算 jvp,那么 linearize 会很有用。然而,为了实现这一点,linearize 会保存中间计算,并比直接应用 jvp 有更高的内存要求。因此,如果所有 tangents 都已知,那么计算 vmap(jvp) 而不是使用 linearize 可能会更有效。

注意

linearize 会计算 func 两次。请提交一个 issue 以便实现单次计算。

示例

>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>