评价此页

torch.func.grad_and_value#

torch.func.grad_and_value(func, argnums=0, has_aux=False)[源代码]#

返回一个用于计算梯度和原始值(或前向计算)的元组的函数。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。必须返回一个单元素 Tensor。如果指定了 has_aux 等于 True,则函数可以返回一个单元素 Tensor 和其他辅助对象的元组:(output, aux)

  • argnums (intTuple[int]) – 指定需要计算梯度的参数。 argnums 可以是单个整数或整数元组。默认为:0。

  • has_aux (bool) – 标志,指示 func 返回一个张量和其他辅助对象: (output, aux)。默认为:False。

返回

用于计算其输入和前向计算的梯度元组的函数。默认情况下,函数输出是相对于第一个参数的梯度张量和原始计算的元组。如果指定 has_auxTrue,则返回梯度元组和带有输出辅助对象的前向计算元组。如果 argnums 是整数元组,则返回一个元组,其中包含相对于每个 argnums 值的输出梯度元组以及前向计算。

返回类型

Callable

请参阅 grad() 查看示例。