NestedIOFunction#
- class torch.autograd.function.NestedIOFunction(*args, **kwargs)[source]#
This class is here only for backward compatibility reasons. Use Function instead of this for any new use case.
- static jvp(ctx, *grad_inputs)[source]#
Define a formula for differentiating the operation with forward-mode automatic differentiation.
This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward() got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward(). Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that input.
You can use the ctx object to pass any value from the forward to this function.
- Return type
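For illustration, a minimal sketch of a jvp() override (the Square function below is a hypothetical example, not part of this API): one tangent is received per input to forward() and one tangent is returned per output.
>>> import torch
>>> import torch.autograd.forward_ad as fwAD
>>>
>>> class Square(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         # Tensors needed by jvp must be saved with save_for_forward
>>>         ctx.save_for_forward(x)
>>>         return x ** 2
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t):
>>>         # x_t is the tangent of x; return the tangent of the output
>>>         x, = ctx.saved_tensors
>>>         return 2 * x * x_t
>>>
>>> with fwAD.dual_level():
>>>     out = Square.apply(fwAD.make_dual(torch.tensor(3.), torch.tensor(1.)))
>>>     tangent = fwAD.unpack_dual(out).tangent  # 2 * 3 * 1 = 6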
- save_for_forward(*tensors)[source]#
Save given tensors for a future call to jvp().
save_for_forward should be called at most once, in either the setup_context() or forward() methods, and all arguments should be tensors.
In jvp(), saved objects can be accessed through the saved_tensors attribute.
Arguments can also be None. This is a no-op.
See Extending torch.autograd for more details on how to use this method.
Example
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> t = torch.tensor(1., dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>>
>>> with fwAD.dual_level():
>>>     a_dual = fwAD.make_dual(a, t)
>>>     d = Func.apply(a_dual, b, c)
- property saved_tensors#
See Function.saved_tensors().
- set_materialize_grads(value)[source]#
Set whether to materialize grad tensors. Default is True.
This should be called only from either the setup_context() or forward() methods.
If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.
Example
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
- static setup_context(ctx, inputs, output)[source]#
There are two ways to define the forward pass of an autograd.Function.
Either:
1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.
2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward). A minimal sketch of this style is shown below.
See torch.autograd.Function.forward() and Extending torch.autograd for more details.
- Return type
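A minimal sketch of the second style (the MulConstant function below is a hypothetical example): forward() takes no ctx, and setup_context() receives the inputs and output and populates the ctx instead.
>>> import torch
>>>
>>> class MulConstant(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x, constant):
>>>         # No ctx here: the forward only computes the output
>>>         return x * constant
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         # inputs is the tuple of args passed to forward
>>>         x, constant = inputs
>>>         ctx.constant = constant
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         # One gradient per forward input; the constant gets None
>>>         return grad_output * ctx.constant, None
>>>
>>> x = torch.tensor(2., requires_grad=True)
>>> MulConstant.apply(x, 3.0).backward()
>>> x.grad  # tensor(3.)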
- static vjp(ctx, *grad_outputs)[source]#
Define a formula for differentiating the operation with backward-mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- Return type
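A minimal sketch of this contract (the Exp function below is a hypothetical example), written as a backward() override, which is equivalent to defining vjp(): one gradient is returned per input to forward().
>>> import torch
>>>
>>> class Exp(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         # d/di exp(i) = exp(i), which was saved in forward
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> x = torch.tensor(1., requires_grad=True)
>>> Exp.apply(x).backward()
>>> x.grad  # tensor(2.7183)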
- static vmap(info, in_dims, *args)[source]#
Define the behavior for this autograd.Function underneath torch.vmap().
For a torch.autograd.Function() to support torch.vmap(), you must either override this static method, or set generate_vmap_rule to True (you may not do both).
If you choose to override this static method, it must accept:
- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.
- *args, which is the same as the args to forward().
The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has a vmapped dimension and what index it is in.
Please see Extending torch.func with autograd.Function for more details.
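A minimal sketch of an overridden vmap staticmethod (the MySin function below is a hypothetical example with a hand-written vmap rule; an elementwise op like this could typically just set generate_vmap_rule = True instead):
>>> import torch
>>>
>>> class MySin(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x):
>>>         return torch.sin(x)
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         x, = inputs
>>>         ctx.save_for_backward(x)
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         x, = ctx.saved_tensors
>>>         return grad_output * torch.cos(x)
>>>
>>>     @staticmethod
>>>     def vmap(info, in_dims, x):
>>>         # in_dims holds one Optional[int] per arg of forward
>>>         x_bdim, = in_dims
>>>         if x_bdim is None:
>>>             # x is not being vmapped over: no batched output dim
>>>             return torch.sin(x), None
>>>         # Move the vmapped dim to the front, apply the op, and
>>>         # report that the output is batched at dim 0
>>>         return torch.sin(x.movedim(x_bdim, 0)), 0
>>>
>>> xs = torch.randn(3, 4)
>>> ys = torch.vmap(MySin.apply)(xs)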