评价此页

torch.Tensor.register_post_accumulate_grad_hook#

Tensor.register_post_accumulate_grad_hook(hook)[源码]#

注册一个在梯度累积后运行的反向钩子。

该 hook 将在所有梯度累积到某个张量之后被调用,这意味着该张量的 `.grad` 字段已被更新。post accumulate grad hook **仅**适用于叶子张量(即没有 `.grad_fn` 字段的张量)。在非叶子张量上注册此 hook 会报错!

钩子应具有以下签名

hook(param: Tensor) -> None

请注意,与其他 autograd hook 不同,此 hook 操作的是需要梯度的张量本身,而不是梯度。hook 可以原地修改和访问其张量参数,包括其 `.grad` 字段。

此函数返回一个句柄,其中包含一个方法 handle.remove(),用于从模块中移除该钩子。

注意

有关此 hook 何时执行以及其执行顺序相对于其他 hook 的更多信息,请参阅 Backward Hooks execution。由于此 hook 在反向传播过程中运行,它将在 `no_grad` 模式下运行(除非 `create_graph` 为 `True`)。如果您需要在 hook 中重新启用 autograd,可以使用 `torch.enable_grad()`。

示例

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook