torch.autograd.function.FunctionCtx.mark_non_differentiable#
- FunctionCtx.mark_non_differentiable(*args)[source]#
将输出标记为不可微分。
此函数最多应被调用一次,可在
setup_context()
或forward()
方法中调用,且所有参数都应为张量输出。这将把输出标记为不需要梯度,从而提高反向传播的效率。您仍然需要在
backward()
中为每个输出接受一个梯度,但它始终会是一个形状与相应输出形状相同的零张量。- 此功能用于例如从排序返回的索引。请参阅示例:
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x): >>> sorted, idx = x.sort() >>> ctx.mark_non_differentiable(idx) >>> ctx.save_for_backward(x, idx) >>> return sorted, idx >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): # still need to accept g2 >>> x, idx = ctx.saved_tensors >>> grad_input = torch.zeros_like(x) >>> grad_input.index_add_(0, idx, g1) >>> return grad_input