
NestedIOFunction#

class torch.autograd.function.NestedIOFunction(*args, **kwargs)[source]#

This class is here only for backward compatibility reasons. Use Function instead of this for any new use case.

backward(*gradients)[source]#

Shared backward utility.

Return type

Any

backward_extended(*grad_output)[source]#

User defined backward.

forward(*args)[source]#

Shared forward utility.

Return type

Any

forward_extended(*input)[source]#

User defined forward.

static jvp(ctx, *grad_inputs)[source]#

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward() got (None will be passed in for non-Tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward(). Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that output.

You can use the ctx object to pass any value from the forward to this function.

Return type

Any
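
For illustration, a minimal sketch of a jvp override (the class name Scale is hypothetical, not part of this API) for an operation that scales its input by 3:

>>> class Scale(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return 3 * x
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t):
>>>         # One tangent argument per input to forward(); one returned
>>>         # tangent per output of forward().
>>>         return 3 * x_t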

mark_dirty(*args, **kwargs)[source]#

See Function.mark_dirty().

mark_non_differentiable(*args, **kwargs)[source]#

See Function.mark_non_differentiable().

save_for_backward(*args)[source]#

See Function.save_for_backward().

save_for_forward(*tensors)[source]#

Save given tensors for a future call to jvp().

save_for_forward should be called at most once, in either the setup_context() or forward() methods, and all arguments should be tensors.

In jvp(), saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example

>>> import torch
>>> import torch.autograd.forward_ad as fwAD
>>>
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> t = torch.tensor(1., dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>>
>>> with fwAD.dual_level():
>>>     a_dual = fwAD.make_dual(a, t)
>>>     d = Func.apply(a_dual, b, c)

property saved_tensors#

See Function.saved_tensors().

set_materialize_grads(value)[source]#

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context() or forward() methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.

Example

>>> import torch
>>> from torch.autograd import Function
>>> from torch.autograd.function import once_differentiable
>>>
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(ctx, inputs, output)[source]#

There are two ways to define the forward pass of an autograd.Function.

Either

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward); see the sketch below.

See torch.autograd.Function.forward() and Extending torch.autograd for more details.

Return type

Any
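
For illustration, a minimal sketch of the second pattern (the class name Mul is hypothetical): forward performs the computation without touching ctx, and setup_context saves what backward needs:

>>> class Mul(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x, y):
>>>         # No ctx here: forward only computes the output.
>>>         return x * y
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         x, y = inputs
>>>         ctx.save_for_backward(x, y)
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         return grad_out * y, grad_out * x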

static vjp(ctx, *grad_outputs)[source]#

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the backward function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-Tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

Return type

Any
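
For illustration, a minimal sketch of a vjp override (the class name Square is hypothetical) for y = x ** 2:

>>> class Square(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.save_for_backward(x)
>>>         return x ** 2
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, = ctx.saved_tensors
>>>         # d(x ** 2)/dx = 2 * x; one returned gradient per input to forward().
>>>         return grad_out * 2 * x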

static vmap(info, in_dims, *args)[source]#

Define the behavior of this autograd.Function under torch.vmap().

For a torch.autograd.Function() to support torch.vmap(), you must either override this static method or set generate_vmap_rule to True (you may not do both).

If you choose to override this static method: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward().

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has a vmapped dimension and what index it is in.

Please see Extending torch.func with autograd.Function for more details.
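
For illustration, a minimal sketch of a vmap override for an elementwise exponential (the class name MyExp is hypothetical); it computes on the batched tensor directly and reports that the output keeps the same batch dimension:

>>> class MyExp(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x):
>>>         return x.exp()
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         ctx.save_for_backward(output)
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         out, = ctx.saved_tensors
>>>         return grad_out * out
>>>
>>>     @staticmethod
>>>     def vmap(info, in_dims, x):
>>>         # in_dims holds one entry per arg of forward(); for an
>>>         # elementwise op the output keeps the same batched dimension.
>>>         x_bdim, = in_dims
>>>         return x.exp(), x_bdim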