torch.autograd.Function.forward#
- static Function.forward(*args, **kwargs)[source]#
定义自定义自动微分函数的前向传播。
该函数应由所有子类重写。定义 forward 的两种方式
用法 1(组合 forward 和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个上下文 ctx 作为第一个参数,后面跟任意数量的参数(张量或其他类型)。
有关更多详细信息,请参阅 组合 forward() 和 setup_context() 或分别定义。
用法 2(分开 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 参数。
相反,您还必须重写
torch.autograd.Function.setup_context()
静态方法来处理 `ctx` 对象的设置。`output` 是 forward 的输出,`inputs` 是 forward 输入的元组。有关更多详细信息,请参阅 扩展 torch.autograd。
上下文可用于存储可在 backward 过程中检索的任意数据。不应直接将张量存储在 `ctx` 上(尽管目前为兼容性不强制执行)。相反,应使用 `ctx.save_for_backward()` 保存张量(如果它们将用于 `backward`(等效于 `vjp`))或使用 `ctx.save_for_forward()` 保存张量(如果它们将用于 `jvp`)。
- 返回类型