将 torch.func 与 autograd.Function 扩展#
创建于: 2023年01月03日 | 最后更新于: 2023年09月14日
您想将 torch.autograd.Function 与 torch.func 变换(如 torch.vmap()、torch.func.grad() 等)一起使用。
有两种主要用例:
- 您希望调用不包含 PyTorch 操作的代码,并使其能够与函数变换一起工作。也就是说, - torch.autograd.Function的 forward/backward/etc 调用会指向其他系统(如 C++、CUDA、NumPy)的函数。
- 您希望指定自定义梯度规则,类似于 JAX 的 custom_vjp/custom_jvp。 
PyTorch 将这两个概念结合到了 torch.autograd.Function 中。
基本用法#
本指南假设您已熟悉 扩展 torch.autograd,其中解释了如何使用 torch.autograd.Function。
torch.autograd.Function 可以有一个接受 ctx 对象的 forward() 方法,或者可以有单独的 forward() 方法(不接受 ctx)和一个 setup_context() 静态方法,后者会修改 ctx 对象。
只有后者才支持函数变换。
- forward()是执行操作的代码,它不应接受- ctx对象。
- setup_context(ctx, inputs, output)是您可以调用- ctx方法的代码。在这里,您应该保存用于反向传播的 Tensor(通过调用- ctx.save_for_backward(*tensors)),或者保存非 Tensor 对象(通过将它们赋值给- ctx对象)。
由于 setup_context() 只接受 inputs 和 output,因此只能保存输入或输出中的对象(如 Tensor)或从它们派生的数量(如 Tensor.shape)。如果您希望为反向传播保存 Function.forward() 的非输入中间激活,则需要将其作为 forward() 的输出返回,以便传递给 setup_context()。
根据变换的不同:
- 为了支持反向模式 AD( - torch.func.grad()、- torch.func.vjp()),- torch.autograd.Function需要一个- backward()静态方法。
- 为了支持 - torch.vmap(),- torch.autograd.Function需要一个- vmap()静态方法。
- 为了支持 - torch.func.jvp(),- torch.autograd.Function需要一个- jvp()静态方法。
- 为了支持变换的组合(如 - torch.func.jacrev()、- torch.func.jacfwd()、- torch.func.hessian())——您可能需要以上多种方法。
为了使 torch.autograd.Function 能够与函数变换任意组合,我们建议除 forward() 和 setup_context() 之外的所有其他静态方法都必须是可变换的:也就是说,它们必须只由 PyTorch 操作组成,或者调用其他 torch.autograd.Function(这些 torch.autograd.Function 可能调用 C++/CUDA/etc)。
下面我们来看一些常见用例的示例。
示例 1:autograd.Function 调用另一个系统#
一种常见情况是 torch.autograd.Function 同时在 forward() 和 backward() 中调用另一个系统(如 C++、CUDA、NumPy、Triton)。
import torch
import numpy as np
def to_numpy(tensor):
    return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )
    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None
现在,为了更方便地使用 NumpySort(隐藏我们作为输出返回的中间变量,并允许默认的 args 和 kwargs),我们创建了一个新函数来调用它。
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result
这是一个健全性检查。
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则#
另一种常见情况是使用 PyTorch 操作实现的 torch.autograd.Function。PyTorch 能够自动计算 PyTorch 操作的梯度,但我们可能希望自定义梯度的计算方式。我们可能希望自定义 backward 的原因包括:
- 提高数值稳定性 
- 改变 backward 的性能特征 
- 改变边缘情况的处理方式(例如,NaN、Inf) 
- 修改梯度(例如,梯度裁剪) 
下面是一个函数 y = x ** 3 的 torch.autograd.Function 示例,其中我们改变了性能特征(一些通常在 backward 传递中进行的计算,计算 dx,现在在 forward 传递中完成)。
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result
现在,为了更方便地使用 NumpySort(并隐藏我们作为输出返回的中间变量),我们创建了一个新函数来调用它。
def my_cube(x):
    result, _ = MyCube.apply(x)
    return result
这是一个计算二阶梯度的健全性检查。
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项#
警告
请仔细阅读 torch.autograd.Function 与 torch.func 变换结合使用的限制。我们无法优雅地捕获许多这种情况,它们会导致未定义的行为。
请不要将正在被变换的 Tensor、requires_grad=True 的 Tensor 或双重 Tensor 捕获到 torch.autograd.Function 的方法中。完全安全的方法是确保 torch.autograd.Function 的任何方法内部使用的唯一 Tensor 必须直接作为输入(或通过 ctx 对象)传递,而不是来自 torch.autograd.Function 外部。
torch.autograd.Function 不处理 PyTree 中的 Tensor(可能包含 Tensor 的任意嵌套 Python 数据结构)。为了让这些 Tensor 被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function。这与 jax.{custom_vjp, custom_jvp} 不同,后者接受 PyTree。
请仅使用 save_for_backward() 或 save_for_forward() 来保存 Tensor。请不要直接将 Tensor 或 Tensor 集合赋值给 ctx 对象——这些 Tensor 将不会被跟踪。
torch.vmap() 支持#
要将 torch.autograd.Function 与 torch.vmap() 一起使用,您必须执行以下操作之一:
- 提供一个 - vmap()静态方法,告诉我们- torch.autograd.Function在- torch.vmap()下的行为。
- 通过设置 - generate_vmap_rule=True来请求我们自动生成它。
自动生成 vmap 规则#
如果您的 torch.autograd.Function 满足以下附加约束,我们就能为其生成 vmap 规则。如果它不满足约束,或者您希望在 vmap 下具有自定义行为,请手动定义 vmap 静态方法(请参阅下一节)。
警告
我们无法轻松检查以下约束并优雅地报错。违反约束可能导致未定义的行为。
- torch.autograd.Function的- forward()、- backward()(如果存在)和- jvp()(如果存在)静态方法必须可以通过- torch.vmap()进行变换。也就是说,它们必须只由 PyTorch 操作组成(而不是例如 NumPy 或自定义 CUDA 内核)。
示例
class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True
    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法#
如果您的 torch.autograd.Function 调用了另一个系统(如 NumPy、C++、CUDA、Triton),那么为了使其能够与 torch.vmap() 或使用它的变换一起工作,您需要手动定义一个 vmap() 静态方法。
根据您想要使用的变换以及您的用例,您可能不需要在所有 torch.autograd.Function 中添加 vmap() 静态方法。
- 例如, - torch.func.jacrev()在 backward 传递上执行- vmap()。因此,如果您只对使用- torch.func.jacrev()感兴趣,那么只有- backward()静态方法需要是可 vmap 的。
我们建议确保您的所有 torch.autograd.Function 都支持 torch.vmap(),特别是如果您正在编写第三方库,并希望您的 torch.autograd.Function 能够与所有 torch.func() 变换的组合一起工作。
概念上,vmap 静态方法负责定义 forward() 在 torch.vmap() 下的行为。也就是说,它定义了如何变换 forward() 以在具有附加维度(被 vmap 的维度)的输入上运行。这类似于 torch.vmap() 如何在 PyTorch 操作上实现:对于每个操作,我们定义一个 vmap 规则(有时也称为“批处理规则”)。
以下是如何定义 vmap() 静态方法:
- 签名是 - vmap(info, in_dims: Tuple[Optional[int]], *args),其中- *args与- forward()的 args 相同。
- vmap 静态方法负责定义 - forward()在- torch.vmap()下的行为。也就是说,给定带有附加维度(由- in_dims指定)的输入,我们如何计算- forward()的批处理版本?
- 对于 - args中的每个参数,- in_dims都有一个对应的- Optional[int]。如果参数不是 Tensor,或者参数没有被 vmap,则为- None;否则,它是一个整数,指定 Tensor 被 vmap 的哪个维度。
- info是一个包含附加元数据的集合,这些元数据可能很有用:- info.batch_size指定了被 vmap 的维度的大小,而- info.randomness是传递给- torch.vmap()的- randomness选项。
- vmap 静态方法的返回值是一个元组 - (output, out_dims)。与- in_dims类似,- out_dims的结构应与- output相同,并且每个输出都包含一个- out_dim,指定输出是否具有 vmap 的维度以及在该维度中的索引。
示例
def to_numpy(tensor):
    return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims
        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)
        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None
    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim
        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)
        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 静态方法应旨在保留整个 Function 的语义。也就是说,(伪代码)grad(vmap(MyFunc)) 应该可以被 grad(map(MyFunc)) 替换。
如果您的 autograd.Function 在 backward 传递中具有任何自定义行为,请牢记这一点。
torch.func.jvp() 支持#
为了支持前向模式 AD,torch.autograd.Function 必须有一个 jvp() 静态方法。详情请参阅 前向 AD autograd.Function。