torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)

Returns a function that computes a tuple of the gradient and the primal (forward) computation.

Parameters
  • func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
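
The func and argnums parameters can be illustrated with a minimal sketch. The function mse and the tensors pred and target are hypothetical names chosen for this illustration; they are not part of the API.

```python
import torch
from torch.func import grad_and_value

# Hypothetical scalar-valued function: it returns a single-element tensor,
# as func is required to do.
def mse(pred, target):
    return ((pred - target) ** 2).mean()

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])

# Default argnums=0: differentiate with respect to the first argument only.
grad_pred, value = grad_and_value(mse)(pred, target)

# argnums=(0, 1): one gradient per listed argument, grouped in a tuple.
(g_pred, g_target), value_again = grad_and_value(mse, argnums=(0, 1))(pred, target)
```

Here grad_pred has the same shape as pred, and value is the scalar loss that a separate forward call to mse would have produced.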

Returns

A function that computes a tuple of the gradients with respect to its inputs and the forward computation. By default, it returns a tuple (grad, value), where grad is the gradient tensor with respect to the first argument and value is the primal output. If has_aux is True, it returns the gradients together with a tuple of the forward output and the auxiliary objects: (grad, (value, aux)). If argnums is a tuple of integers, the gradient element is itself a tuple holding one gradient per argnums entry: ((grad_0, grad_1, ...), value).
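
As a sketch of the has_aux=True return structure, the example below unpacks the nested tuple. The function loss_with_aux and the input x are hypothetical and serve only to illustrate the shape of the result.

```python
import torch
from torch.func import grad_and_value

# Hypothetical function returning (output, aux): the first item is the
# single-element tensor that gets differentiated, the second is carried
# through as auxiliary data.
def loss_with_aux(x):
    y = x.sin()
    return y.sum(), y

x = torch.tensor([0.0, 1.0, 2.0])

# With has_aux=True the result is (grad, (value, aux)).
grad, (value, aux) = grad_and_value(loss_with_aux, has_aux=True)(x)
```

In this sketch grad equals cos(x), value is the summed output, and aux is the elementwise sin(x) returned alongside it.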

Return type

Callable

See grad() for examples