评价此页

torch.autograd.Function.forward#

static Function.forward(*args, **kwargs)[source]#

定义自定义自动微分函数的前向传播。

此函数应被所有子类覆盖。定义 forward 的两种方法:

用法 1 (合并 forward 和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2 (分开 forward 和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 参数。

  • 相反,您还必须覆盖 torch.autograd.Function.setup_context() 静态方法来处理 ctx 对象的设置。 output 是 forward 的输出,inputs 是 forward 输入的元组。

  • 有关更多详细信息,请参阅 扩展 torch.autograd

上下文可用于存储可以在反向传播期间检索的任意数据。不应直接在 ctx 上存储张量(尽管出于向后兼容性目前不强制执行此操作)。而是应使用 ctx.save_for_backward() 保存张量(如果打算在 backward(等同于 vjp)中使用),或者使用 ctx.save_for_forward() 保存张量(如果打算在 jvp 中使用)。

返回类型

任何