评价此页

扩展 PyTorch#

创建时间: 2017年1月16日 | 最后更新时间: 2025年5月7日

在本笔记中,我们将介绍扩展 torch.nntorch.autogradtorch 的方法,以及编写自定义 C++ 扩展。

添加新运算符#

PyTorch 提供了大量的运算符库,可以对张量(Tensor)进行操作(例如 torch.add()torch.sum() 等)。然而,您可能希望为 PyTorch 添加一个新的自定义操作,并使其行为类似于 PyTorch 的内置运算符。为了做到这一点,您必须通过 Python torch.library 或 C++ TORCH_LIBRARY API 将自定义操作注册到 PyTorch。

有关更多详细信息,请参阅 PyTorch 自定义运算符登录页

扩展 torch.autograd#

autograd 添加操作需要为每个操作实现一个新的 Function 子类。回想一下,Functions 是 autograd 用来编码操作历史和计算梯度的。

本指南的第一部分侧重于后向模式 AD,因为它是最广泛使用的功能。最后一部分讨论了前向模式 AD 的扩展。

何时使用#

通常,如果您想在模型中执行计算,而这些计算是不可微的或依赖于非 PyTorch 库(例如 NumPy),但仍希望您的操作与其他操作链接并与 autograd 引擎一起工作,则应实现自定义函数。

在某些情况下,自定义函数还可以用于提高性能和内存使用:如果您使用 C++ 扩展 实现前向和后向传递,则可以将它们包装在 Function 中以与 autograd 引擎进行接口。如果您想减少为后向传递保存的缓冲区数量,则可以使用自定义函数将操作组合在一起。

何时不使用#

如果您的函数已经可以用 PyTorch 的内置操作来编写,那么它的后向图(很可能)已经可以被 autograd 记录。在这种情况下,您无需自己实现后向函数。可以考虑使用普通的 Python 函数。

如果您需要维护状态(即可训练参数),则应(也)使用自定义模块。有关扩展 torch.nn 的更多信息,请参阅下面的部分。

如果您想在后向传递过程中修改梯度或执行副作用,请考虑注册一个 张量模块 钩子。

如何使用#

请执行以下步骤:1. 继承 Function 并实现 forward()、(可选)setup_context()backward() 方法。2. 调用 ctx 参数上的正确方法。3. 声明您的函数是否支持 双后向。4. 使用 gradcheck 验证您的梯度是否正确。

步骤 1: 继承 Function 后,您需要定义 3 个方法

  • forward() 是执行操作的代码。它可以接受任意数量的参数,其中一些参数是可选的(如果您指定了默认值)。此处接受所有种类的 Python 对象。Tensor 参数(如果其 requires_grad=True 以跟踪历史)将在调用前转换为不跟踪历史的张量,并且其使用将被记录在图中。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑作为直接参数调用的张量。您可以返回单个 Tensor 输出,或者如果存在多个输出,则返回张量的 tuple。此外,请参考 Function 的文档,以查找只能从 forward() 中调用的有用方法的说明。

  • setup_context()(可选)。您可以编写一个“组合式”的 forward(),它接受一个 ctx 对象,或者(从 PyTorch 2.0 开始)一个不接受 ctx 的独立 forward() 方法和一个 setup_context() 方法,其中 ctx 的修改发生在其中。forward() 应该包含计算逻辑,而 setup_context() 应该只负责 ctx 的修改(而不应包含任何计算)。通常,独立的 forward()setup_context() 更接近 PyTorch 原生操作的工作方式,因此与各种 PyTorch 子系统更具可组合性。有关更多详细信息,请参阅 组合式或独立的 forward() 和 setup_context()

  • backward()(或 vjp())定义了梯度公式。它将接收与输出数量相同的 Tensor 参数,每个参数代表相对于该输出的梯度。重要的是永远不要原地修改它们。它应该返回与输入数量相同的张量,每个张量包含相对于其对应输入的梯度。如果您的输入不需要梯度(needs_input_grad 是一个布尔元组,指示每个输入是否需要梯度计算),或者是非 Tensor 对象,则可以返回 python:None。此外,如果 forward() 有可选参数,您可以返回比输入更多的梯度,只要它们都是 None

步骤 2: 您有责任正确使用 ctx 中的函数,以确保新的 Function 能与 autograd 引擎正常工作。

  • save_for_backward() 应用于保存后向传递所需的任何张量(与直接保存在 ctx 中的不同)。您不能将 save_for_backward 用于非张量;您应该直接将它们存储在 ctx 中。

    通过 save_for_backward 保存张量:1. 允许 autograd 引擎在 autograd.Function 的后向计算完成后立即清除它们。(如果张量直接保存在 ctx 中,它将不必要地在 autograd 图的整个生命周期内保持活动状态——通常直到迭代结束。)2. 有助于避免某些引用循环(例如,因为 autograd.Function 的张量输出本身会保留对 ctx 的引用)。3. 对于与激活检查点和卸载等功能(依赖于 torch.autograd.graph.saved_tensors_hooks)的兼容性很重要。

    如果保存的张量既不是输入也不是输出,您的 Function 可能不支持双后向(请参阅步骤 3)。

  • mark_dirty() 必须用于标记在正向函数中被原地修改的任何输入。

  • mark_non_differentiable() 必须用于告知引擎输出是否不可微分。默认情况下,所有可微分类型的输出张量都将设置为需要梯度。不可微分类型的张量(即整数类型)永远不会被标记为需要梯度。

  • set_materialize_grads() 可用于告诉 autograd 引擎在输出不依赖于输入的情况下优化梯度计算,方法是不具体化传递给后向函数的 grad 张量。也就是说,如果设置为 False,Python 中的 None 对象或 C++ 中的“未定义张量”(tensor x,其中 x.defined() 为 False)将不会在调用后向函数之前转换为填充零的张量,因此您的代码将需要像处理填充零的张量一样处理这些对象。此设置的默认值为 True。

步骤 3: 如果您的 Function 不支持双后向,您应该通过用 once_differentiable() 装饰后向函数来明确声明这一点。使用此装饰器,尝试通过您的函数进行双后向将产生错误。有关双后向的更多信息,请参阅我们的双后向教程。

步骤 4: 建议您使用 torch.autograd.gradcheck() 来检查您的后向函数是否通过计算您的后向函数的雅可比矩阵并将其与通过有限差分数值计算的雅可比矩阵逐元素进行比较,从而正确计算了其前向函数的梯度。

示例#

下面是一个 Linear 函数的代码示例,带有附加注释。

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

现在,为了使这些自定义操作更容易使用,我们建议对其进行别名或将其包装在函数中。包装在函数中使我们能够支持默认参数和关键字参数。

# Option 1: alias
linear = LinearFunction.apply

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

在这里,我们提供了另一个由非张量参数参数化的函数的示例。

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

在这里,我们通过调用 set_materialize_grads(False) 来优化上述示例。

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None

        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

如果您需要在 forward() 中计算的任何“中间”张量被保存,那么它们必须被返回为输出,或者组合 forwardsetup_context()(请参阅 组合式或独立的 forward() 和 setup_context())。请注意,这意味着如果您希望梯度流经这些中间值,您需要为它们定义梯度公式(另请参阅 双后向教程)。

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

注意

传递给 backward 的输入,即 grad_output,也可以是跟踪历史的张量。因此,如果 backward 是使用可微分操作实现的(例如,调用另一个自定义 Function),则高阶导数将起作用。在这种情况下,使用 save_for_backward 保存的张量也可以在后向中使用,并且具有回传的梯度,但保存在 ctx 中的张量将没有梯度回传。如果您需要为保存在 ctx 中的张量回传梯度,您应该将其作为自定义 Function 的输出,并使用 save_for_backward 保存。

您可能需要检查实现的后向方法是否实际计算了函数的导数。可以通过使用小的有限差分进行数值逼近来比较。

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

有关有限差分梯度比较的更多详细信息,请参阅 数值梯度检查。如果您的函数用于高阶导数(对后向传递进行微分),则可以使用同一包中的 gradgradcheck 函数来检查高阶导数。

组合式或独立的 forward()setup_context()#

定义 Function 有两种主要方式。要么

  • 定义一个 forward(),它将前向计算逻辑与 setup_context() 结合起来

  • (从 PyTorch 2.0 开始)定义一个独立的 forward()setup_context()

我们推荐第二种选项(独立的 forward()setup_context()),因为这更接近 PyTorch 原生操作的实现方式,并且与 torch.func 转换具有可组合性。然而,我们计划在未来同时支持这两种方法;将 forward()setup_context() 结合:可以实现更大的灵活性,因为您可以在不将中间结果返回为输出的情况下保存它们。

请参阅上一节,了解如何使用独立的 forward()setup_context() 来定义 Function

以下是如何使用组合式 forward()setup_context() 来定义 Function 的示例。

class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

前向模式 AD#

覆盖前向模式 AD 公式具有非常相似的 API,但存在一些细微差别。您可以实现 jvp() 函数。

它将接收与输入数量相同的 Tensor 参数,每个参数代表相对于该输入的梯度。它应该返回与输出数量相同的张量,每个张量包含相对于其对应输出的梯度。jvp() 将在 forward() 方法调用之后、apply() 返回之前调用。

jvp()backward() 函数有一些细微的差别。

  • 您可以使用 ctxforward() 中的任何数据传递到 jvp() 函数。如果这些状态不需要用于 backward(),您可以通过在 jvp() 函数的末尾执行 del ctx.foo 来显式释放它。

  • jvp() 的实现必须是后向可微的,或者明确检查是否所有前向模式梯度都设置了 requires_grad

  • jvp() 函数必须匹配 forward() 的视图/原地修改行为。例如,如果第 i 个输入被原地修改,那么第 i 个梯度也必须被原地修改。类似地,如果第 j 个输出是第 k 个输入的视图,那么返回的第 j 个输出梯度也必须是第 k 个输入梯度的视图。

  • 因为用户无法指定需要计算哪个梯度,所以 jvp() 函数应该始终计算所有输出的梯度。

  • 前向模式梯度确实尊重 set_materialize_grads() 设置的标志,并且当该标志被禁用时,您可以获得 None 的输入梯度。

torch.func 转换和/或 torch.vmap()#

有关详细信息,请参阅 使用 autograd.Function 扩展 torch.func

扩展 torch.nn#

nn 导出两种接口——模块及其函数式版本。您可以通过这两种方式进行扩展,但我们建议将模块用于所有类型的层(包含任何参数或缓冲区),并建议将函数式形式用于无参数操作,如激活函数、池化等。

在上面关于添加函数式版本的操作的部分已经完全覆盖了。

添加 Module#

由于 nn 大量使用了 autograd,添加一个新的 Module 需要实现一个执行操作并能计算梯度的 Function。从现在开始,我们假设我们想实现一个 Linear 模块,并且我们已经实现了如上一个列表所示的函数。添加这个只需要很少的代码。现在,有两个函数需要实现:

  • __init__ (可选) - 接受诸如卷积核大小、特征数等参数,并初始化参数和缓冲区。

  • forward() - 实例化一个 Function 并使用它来执行操作。这与上面所示的函数式包装器非常相似。

下面是如何实现一个 Linear 模块

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

扩展 torch Python API#

你可以通过定义一个具有与 Tensor 匹配的方法的自定义类来创建模拟 Tensor 的自定义类型。但是,如果你想能够将这些类型传递给像顶层 torch 命名空间中接受 Tensor 操作数的函数,例如 torch.add(),该怎么办?

如果你的自定义 Python 类型定义了一个名为 __torch_function__ 的方法,当你的自定义类的实例被传递给 torch 命名空间中的函数时,PyTorch 将调用你的 __torch_function__ 实现。这使得为 __torch_function__ 实现可以调用的 torch 命名空间中的任何函数定义自定义实现成为可能,从而允许你的用户使用他们已经为 Tensor 编写的现有 PyTorch 工作流程来利用你的自定义类型。这适用于与 Tensor 无关的“鸭子类型”以及 Tensor 的用户定义子类。

使用类 Tensor 的类型扩展 torch#

注意

此功能受到 NumPy __array_function__ 协议的启发。有关更多详细信息,请参阅 NumPy 文档NEP-0018

为了具体说明,让我们从一个简单的例子开始,该例子说明了 API 分派机制。我们将创建一个表示二维标量张量的自定义类型,由阶数 N 和对角线项上的值 value 参数化。

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

该设计的第一个迭代用处不大。 ScalarTensor 的主要功能是提供比基本张量类更紧凑的标量张量字符串表示。

>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

如果我们尝试在 torch API 中使用此对象,我们将遇到问题。

>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

ScalarTensor 添加 __torch_function__ 实现可以使上述操作成功。让我们重新进行实现,这次添加一个 __torch_function__ 实现。

HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

__torch_function__ 方法接受四个参数:func,指向正在重写的 torch API 函数的引用;types,实现 __torch_function__ 的 Tensor-like 类型列表;args,传递给函数的参数元组;以及 kwargs,传递给函数的关键字参数字典。它使用名为 HANDLED_FUNCTIONS 的全局分派表来存储自定义实现。此字典的键是 torch 命名空间中的函数,值是 ScalarTensor 的实现。

注意

使用全局分派表不是 __torch_function__ API 的强制性部分,它只是用于构建重写实现的一种有用的设计模式。

这个类定义不足以让 torch.mean 在我们传递 ScalarTensor 时执行正确操作——我们还需要为 ScalarTensor 操作数定义 torch.mean 的实现,并将实现添加到 HANDLED_FUNCTIONS 分派表字典中。一种方法是定义一个装饰器。

import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

该装饰器可以应用于我们重写的实现。

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

有了这个更改,我们现在就可以在 ScalarTensor 中使用 torch.mean

>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4

当然,torch.mean 是最简单的重写函数类型的一个例子,因为它只有一个操作数。我们可以使用相同的机制来重写一个接受多个操作数的函数,其中任何一个操作数都可能是定义了 __torch_function__ 的张量或类张量,例如对于 torch.add()

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))

此版本在两个操作数都是 ScalarTensor 实例时有一个快速路径,并且还有一个较慢的路径,当任何一个操作数不是 ScalarTensor 时,该路径会退化为将数据转换为张量。这使得重写函数在任一操作数是 ScalarTensor 或常规 Tensor 时都能正确工作。

>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])

请注意,我们对 add 的实现不接受 alphaout 作为关键字参数,而 torch.add() 则接受。

>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'

为了提高速度和灵活性,__torch_function__ 分派机制不会检查重写函数的签名是否与 torch API 中要重写的函数的签名相匹配。对于某些应用程序,忽略可选参数是可以的,但为了确保与 Tensor 的完全兼容性,torch API 函数的用户实现应注意精确地模拟要重写的函数的 API。

torch API 中的函数如果没有显式重写,将从 __torch_function__ 返回 NotImplemented。如果所有具有 __torch_function__ 定义的实例的操作数都返回 NotImplemented,PyTorch 将引发 TypeError。这意味着大多数时候,对于某个类型没有显式重写的操作,当传递该类型的一个实例时,将引发 TypeError

>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]

实际上,这意味着如果您想使用 __torch_function__ 实现来执行重写,您将需要显式实现完整的 torch API 或您关心的 API 的整个子集。这可能是一项艰巨的任务,因为完整的 torch API 非常庞大。

另一种选择是不为未处理的操作返回 NotImplemented,而是当没有重写可用时,将 Tensor 传递给原始的 torch 函数。例如,如果我们修改 ScalarTensor__torch_function__ 实现如下:

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

那么 torch.mul() 将正常工作,尽管即使两个操作数都是 ScalarTensor 实例,返回类型也将始终是 Tensor 而不是 ScalarTensor

>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])

另请参阅下面的 MetadataTensor 示例,了解此模式的另一种变体,但它始终返回 MetadataTensor 以通过 torch API 中的操作传播元数据。

__torch_function__ 协议旨在全面覆盖 API,部分覆盖可能会导致不良结果,特别是某些函数引发 TypeError。这对于子类尤其如此,其中 torch.addtorch.Tensor.__add__torch.Tensor.add 三者都必须覆盖,即使它们返回完全相同的结果。未能做到这一点也可能导致无限递归。如果有人需要实现 torch.Tensor 子类的函数,他们必须在实现中使用 super().__torch_function__

子类化 torch.Tensor#

从版本 1.7.0 开始,应用于 torch.Tensor 子类的方法以及公共 torch.* 命名空间中的函数将返回子类实例而不是 torch.Tensor 实例。

>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

如果存在多个子类,默认情况下将选择层次结构中最低的子类。如果没有唯一的方法来确定这种情况,则会引发 TypeError

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]

如果希望对所有张量方法进行全局重写,可以使用 __torch_function__。下面是一个记录所有函数/方法调用的示例。

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

但是,如果想重写 Tensor 子类上的方法,可以通过直接重写方法(通过为子类定义它)或使用 __torch_function__ 并与 func 匹配来实现。

对于子类中的 __torch_function__,应该始终调用 super().__torch_function__(func, ...) 而不是直接调用 func,这与 1.7.0 版本之前的 func 情况相同。未能这样做可能会导致 func 递归回 __torch_function__,从而导致无限递归。

使用类 Tensor 的包装器类型扩展 torch#

另一个有用的场景是包装 Tensor 的类型,无论是作为属性还是通过子类化。下面我们实现这种类型的一个特例,即 MetadataTensor,它将元数据字典附加到 Tensor 上,并将其通过 torch 操作进行传播。由于这是对完整 torch API 的通用包装,我们不需要单独实现每个重写,因此我们可以使 __torch_function__ 实现对允许的操作更加宽松。

class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        args = [getattr(a, '_t', a) for a in args]
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

这个简单的实现不一定适用于 torch API 中的每个函数,但它足以捕获大多数常见操作。

>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[1, 4],
        [3, 8]])

具有 __torch_function__ 实现的多个类型上的操作#

可以使用具有 __torch_function__ 实现的多个不同类型来使用 torch API,但这需要格外小心。在这种情况下,规则是:

  • 分派操作会收集每个操作数的 __torch_function__ 的所有不同实现,并按顺序调用它们:子类优先于超类,否则按照运算符表达式的从左到右顺序。

  • 如果返回 NotImplemented 以外的任何值,则将该值作为结果返回。实现者可以通过返回 NotImplemented 来表明它们不实现某个操作。

  • 如果所有 __torch_function__ 实现都返回 NotImplemented,PyTorch 将引发 TypeError

PyTorch API 重写覆盖范围测试#

实现 __torch_function__ 的一个令人头疼的方面是,如果某些操作有重写而另一些没有,用户充其量会看到不一致的体验,最坏的情况下,在使用没有重写的函数时会遇到运行时错误。为了简化这个过程,PyTorch 提供了一个面向开发者的 API 来确保对 __torch_function__ 重写的全面支持。这个 API 是私有的,可能会在未来未经通知而更改。

首先,要获取所有可重写函数的列表,请使用 torch.overrides._get_overridable_functions。这将返回一个字典,其键是 PyTorch Python API 中的命名空间,其值是该命名空间中可以重写的函数列表。例如,让我们打印 torch.nn.functional 中前 5 个可重写函数的名称。

>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']

这个函数列表使得迭代所有可重写的函数成为可能,但在实践中,这不足以编写针对所有这些函数的测试,而不必费力地手动复制每个函数的签名来进行每个测试。为了简化这个过程,torch.overrides._get_testing_overrides 函数返回一个字典,该字典将 PyTorch API 中的可重写函数映射到虚拟 lambda 函数,这些函数具有与原始函数相同的签名,但无条件返回 -1。这些函数最适合与 inspect 一起使用,以分析原始 PyTorch 函数的函数签名。

>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>

最后,torch.overrides.get_ignored_functions 返回一个元组,其中包含明确不能被 __torch_function__ 重写的函数。这个列表对于确认 get_overridable_functions 返回的字典中不存在的函数不能被重写很有用。

扩展 torch 原生 API#

虽然 __torch_function__ 允许我们有效地扩展 PyTorch 的纯 Python 组件的行为,但它不允许我们扩展 PyTorch 中用 C++ 实现的部分。为此,Tensor 的子类也可以定义 __torch_dispatch__,它可以在 C++ 级别重写行为。

为了有效使用此功能,了解 PyTorch 的原生部分是如何实现的很重要。其中最重要的组件是我们称之为“分派器”的东西(最好的描述可以在这篇 博客文章 中找到,尽管它已略微过时)。正如其名称所示,它负责为函数的特定调用调用正确的后端函数。例如,当调用 torch.add(a, b) 时,分派器将检查两个参数,确定为这次特定调用应该使用哪个“功能”(autograd、autocast、functionalization 等)和哪个“后端”(CPU、CUDA、MPS 等),最后调用所有正确的内核。内核的一个非常常见的操作是“重新分派”。例如,当在 GPU 上使用 autocast 运行神经网络时,第一次调用将是 autocast 内核,它将处理任何潜在的 autocast 逻辑并向下重新分派。下一项功能是 autograd,它将正确创建 autograd 图,然后向下重新分派。最后,我们到达 CUDA 的后端内核,它将启动正确的 CUDA 内核并返回最终结果。在退出时,autograd 将图附加到输出,最后,autocast 将有机会在退出时进行任何必要的更新。

分派器的一个配置是所有这些功能和后端键的调用顺序。最新的列表及其顺序可以在 DispatchKey.h 中的 DispatchKey 枚举中找到。为了扩展 torch 的目的,该讨论中最重要的排序子集是:

vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends

为了本讨论的目的,最重要的键是 Python,因为每个定义了 __torch_dispatch__ 方法的 Tensor 子类都将调用此功能。从那里调用用户定义的函数,并且可以任意重写行为。从那里,再次调用提供的 func 将执行“重新分派”。

此实现的一些重要含义是:

  • 此代码在“所有功能之下”运行。因此,它像常规后端一样,仅负责生成每个 Tensor 的输出值(并且可以,也应该忽略 autograd、autocast 等所有高级功能)。

  • 如果任何高级功能实现了给定的函数而没有重新分派,它将永远不会到达 Python 键,因此 __torch_dispatch__ 回调将永远不会被触发。这尤其发生在 CompositeImplicitAutograd 函数上,这些函数在 Autograd 级别进行计算而不进行重新分派。这是因为 CompositeImplicitAutograd 函数通过隐式调用其他原生操作来指定其 autograd 公式,因此在 Autograd 级别,该函数被分解为其原生操作,然后进行计算。

  • 在回调到 Python 和包装结果时,使用的是与常规 PyTorch Python/C++ 绑定相同的转换。特别是,某些对象无法在 Python 中表示,需要特殊处理(例如,未定义的 Tensor 会变成 None)。

  • 我们的原生函数被惰性地填充为 torch.ops.{namespace}.{func_name}.{overload_name},作为可调用的 Python 对象,以便于从 Python 进行交互。传递给 __torch_dispatch__func 对象始终是此命名空间中的一个条目。此命名空间可用于直接调用原生操作并绕过常规 Python API 和绑定代码。

就像 __torch_function__ 能够干预 torch 的所有 Python API 和 Tensor 方法一样,__torch_dispatch__ 能够拦截所有对 aten 原生 API 的调用。请注意,Tensor 上的所有方法在进入分派器之前都会转换为函数调用,因此在这里会显示为函数调用:torch.add(a, 2)a + 2 将导致完全相同的 aten 调用。其中大多数函数定义在 native_functions.yaml 中,该文件指定了这些函数的属性及其后端实现。它们的实现以及指定的特性通过 codegen 自动注册。一些更奇特的函数或特性也在 C++ 代码库的其他地方或用户定义的 C++ 扩展中注册。

还可以使用 torch.library 添加新的原生函数。这个 Python 功能允许定义和/或向原生函数添加新实现。这可用于添加缺失的内核、替换现有的内核或定义全新的原生函数。

您可以在 subclass_zoo 仓库中找到许多基于 __torch_dispatch__ 的子类的示例。

__torch_dispatch__ 调用约定#

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    pass

当用户调用具有 __torch_dispatch__ 的输入的运算符时,该调用可能会被转发到 __torch_dispatch__。在调用 __torch_dispatch__ 之前,args 和 kwargs 会被规范化,即:

  • kwargs 由运算符模式中的关键字参数组成。如果某个关键字参数与其在模式中的默认值相等,则不会传递它。

  • args 由所有其他参数组成,无论它们是如何传递给运算符的(位置参数 vs 关键字参数)。如果某个参数等于其默认值,并且它是最右边的位置参数,或者它右边的所有参数都没有传递,那么它将不会被传递。

使用模式扩展所有 torch API#

不幸的是,有些函数不接受 Tensor 输入。这意味着上述子类方法不能用于重写 PyTorch 所有函数的行为。此外,如果用例需要拦截每个函数调用,将每个 Tensor 更改为子类可能会过于侵入。

为了解决这个用例,我们引入了“模式”的概念。这些模式存在于 __torch_function____torch_dispatch__ 重写中,分别通过继承 torch.overrides.TorchFunctionModetorch.utils._python_dispatch.TorchDispatchMode 来创建,并用作上下文管理器。

为了简化描述它如何与子类和其他模式交互,每当进入模式的上下文管理器时,所有函数都表现得好像在参数列表的开头有一个额外的 Tensor 参数,其中模式是子类。这意味着,特别是所有模式处理程序都将在任何子类处理程序之前被调用,并且对应于内部上下文管理器的模式将始终首先运行。

还值得注意的是,在给定的模式处理程序中,此特定模式被禁用,可以通过执行 with self: 来手动重新启用。

这是一个展示各种日志模式的示例

import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()

print("TorchFunctionMode logging:")
with FunctionLog():
    f()

print("TorchDispatchMode logging:")
with DispatchLog():
    f()

这将打印出以下内容,并附带额外注释

TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})