torch.autograd.Function.forward#
- static Function.forward(*args, **kwargs)[source]#
定义自定义自动微分函数的前向传播。
此函数应被所有子类覆盖。定义 forward 的两种方法:
用法 1 (合并 forward 和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个 context ctx 作为第一个参数,后跟任意数量的参数(张量或其他类型)。
有关更多详细信息,请参阅 合并或分开 forward() 和 setup_context()。
用法 2 (分开 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 参数。
相反,您还必须覆盖
torch.autograd.Function.setup_context()
静态方法来处理ctx
对象的设置。output
是 forward 的输出,inputs
是 forward 输入的元组。有关更多详细信息,请参阅 扩展 torch.autograd。
上下文可用于存储可以在反向传播期间检索的任意数据。不应直接在 ctx 上存储张量(尽管出于向后兼容性目前不强制执行此操作)。而是应使用
ctx.save_for_backward()
保存张量(如果打算在backward
(等同于vjp
)中使用),或者使用ctx.save_for_forward()
保存张量(如果打算在jvp
中使用)。- 返回类型