torch.autograd.Function.backward#
- static Function.backward(ctx, *grad_outputs)[source]#
定义使用反向模式自动微分来区分操作的公式。
此函数应被所有子类重写。(定义此函数等同于定义
vjp
函数。)它必须接受一个上下文
ctx
作为第一个参数,然后是forward()
返回的输出数量(对于 forward 函数的非张量输出将传入 None),并且它应该返回与forward()
的输入数量相等的张量。每个参数是关于给定输出的梯度,并且每个返回值应该是关于相应输入的梯度。如果输入不是张量或是不需要梯度的张量,你可以只为该输入传入 None 作为梯度。上下文可用于检索在 forward 传播期间保存的张量。它还有一个属性
ctx.needs_input_grad
,它是一个布尔值元组,表示每个输入是否需要梯度。例如,如果forward()
的第一个输入需要计算关于输出的梯度,那么backward()
将会有ctx.needs_input_grad[0] = True
。- 返回类型