评价此页

使用 C++ 为新后端扩展 Dispatcher#

创建于: 2021年02月01日 | 最后更新: 2024年09月23日 | 最后验证: 2024年11月05日

在本教程中,我们将逐步介绍将 Dispatcher 扩展到 PyTorch 仓库外部添加新设备所需的所有必要步骤,并对其进行维护以与原生 PyTorch 设备保持同步。在此,我们假设您熟悉如何在 C++ 中 注册已分派的运算符 以及如何编写 自定义自动微分函数

注意

本教程涉及 PyTorch 内部的许多组件,这些组件正在积极改进中,因此,如果您决定遵循本教程,请预计 API 会发生变化。我们将使本教程与最新的 API 保持同步。

什么是新后端?#

为 PyTorch 添加新后端需要后端扩展者进行大量开发和维护。在添加新后端之前,让我们先考虑一些常见用例和推荐解决方案:

  • 如果您有现有 PyTorch 运算符的新算法,请向 PyTorch 发送 PR。

  • 如果您想提出新的运算符,请向 PyTorch 发送功能请求/PR。

  • 如果您想添加对新设备/硬件(如 Google TPU 和定制芯片)的支持,这通常需要使用硬件特定的 API 来编写内核,请遵循本教程并向 PyTorch 添加一个“out-of-tree”(树外)后端。

  • 如果您想为现有运算符添加支持,但采用不同的 Tensor 布局/表示(如稀疏和量化),这会强制您的内核以一种更有效的方式编写,考虑到布局/表示的限制,请遵循本教程并向 PyTorch 添加一个“out-of-tree”(树外)后端。

在本教程中,我们将主要关注添加一个新的“out-of-tree”(树外)设备。为不同的张量布局添加“out-of-tree”(树外)支持可能与设备共享许多通用步骤,但我们尚未看到此类集成的示例,因此可能需要 PyTorch 额外的工作来支持它。

为您的后端获取 Dispatch Key#

PyTorch 运算符是用 C++ 实现的,并通过 Python 绑定在 Python 前端中可用。PyTorch Dispatcher 将运算符的实现划分为多个内核,每个内核都与特定的 Dispatch Key 相关联。在 PyTorch 中支持新后端本质上意味着用 C++ 为每个 PyTorch 运算符编写一个内核,然后将它们注册到一个代表您定制后端的 Dispatch Key 中。

Dispatch Key 是您在 Dispatcher 系统中的标识符。Dispatcher 查看输入 Tensor 上携带的 Dispatch Key,并相应地调用正确的内核。PyTorch 提供了三个保留的 Dispatch Key(及其对应的 Autograd Key),用于原型化“out-of-tree”(树外)后端扩展:

  • PrivateUse1/AutogradPrivateUse1

  • PrivateUse2/AutogradPrivateUse2

  • PrivateUse3/AutogradPrivateUse3

您可以选择以上任何一个 Key 来原型化您的定制后端。要创建一个 Tensor 在 PrivateUse1 后端上,您需要在 TensorImpl 构造函数中设置 Dispatch Key。

/* Example TensorImpl constructor */
TensorImpl(
    Storage&& storage,
    DispatchKeySet ks,
    const caffe2::TypeMeta data_type);

// To create a TensorImpl on PrivateUse1 backend, pass in the following ks to TensorImpl creation.
DispatchKeySet ks = c10::DispatchKeySet{c10::DispatchKey::PrivateUse1, c10::DispatchKey::AutogradPrivateUse1};

请注意,上面的 TensorImpl 类假定您的 Tensor 由类似 CPU/CUDA 的存储支持。我们还为没有存储的后端提供了 OpaqueTensorImpl。您可能需要调整/覆盖某些方法以适应您的定制硬件。PyTorch 仓库中的一个示例是 Vulkan TensorImpl

注意

一旦原型完成,并且您计划为后端扩展进行定期发布,请随时提交 PR 到 pytorch/pytorch 以为您的后端保留专用的 Dispatch Key。

获取 PyTorch 运算符的完整列表#

PyTorch 在生成的 C++ 文件 build/aten/src/ATen/RegistrationDeclarations.h 中提供了可扩展的 C++ 运算符的完整列表。此文件仅在从源代码构建 PyTorch 后可用。这是一个文件片段:

Tensor abs(const Tensor & self); // {"schema": "aten::abs(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
Tensor & abs_(Tensor & self); // {"schema": "aten::abs_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
Tensor & abs_out(Tensor & out, const Tensor & self); // {"schema": "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor absolute(const Tensor & self); // {"schema": "aten::absolute(Tensor self) -> Tensor", "dispatch": "False", "default": "False"}
Tensor & absolute_(Tensor & self); // {"schema": "aten::absolute_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "False"}
Tensor & absolute_out(Tensor & out, const Tensor & self); // {"schema": "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "False"}
Tensor angle(const Tensor & self); // {"schema": "aten::angle(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
Tensor & angle_out(Tensor & out, const Tensor & self); // {"schema": "aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor sgn(const Tensor & self); // {"schema": "aten::sgn(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}

单个运算符有多个相关字段。让我们以 abs_out 为例来分解它:

  • Tensor & abs_out(Tensor & out, const Tensor & self); 是运算符的 C++ 签名,您的 C++ 内核应完全匹配此签名。

  • aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 是表示运算符的唯一模式(schema),它还包含与 C++ 签名相比的别名和突变(mutation)注释。这是 Dispatcher 用来查找运算符的唯一标识符。

  • dispatchdefault 是布尔字段,提供了有关原生 PyTorch 内核可以执行的操作的信息,从而暗示后端扩展者是否需要实现内核。更多详细信息请参阅 为新后端注册内核

为新后端注册内核#

要将您的内核注册到 PyTorch Dispatcher,您可以使用 在 C++ 中注册已分派的运算符 中描述的 TORCH_LIBRARY_IMPL API。

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op1>, &my_op1);
  m.impl(<schema_my_op2>, &my_op2);
  m.impl(<schema_my_op2_backward>, &my_op2_backward);
}

现在让我们深入了解哪些运算符需要定制后端内核,以及内核中具体包含什么。

PyTorch 目前拥有超过 1600 个运算符,并且仍在增长。后端扩展者很难跟上这个速度。即使是 CPU 或 CUDA 等原生后端,为每个新运算符编写专用内核通常也需要大量工作。

幸运的是,一些原生 PyTorch 内核的编写方式是将它们分解为已知运算符的组合。换句话说,您只需要实现一组已知的运算符(下面需要注册的运算符),而不是所有 PyTorch 运算符。

PyTorch 运算符可分为两类:

  • 需要注册的运算符:这些运算符的原生 PyTorch 实现是特定于后端的,因此必须为定制后端提供内核。否则,在定制后端上调用此类运算符将出错。

    • RegistrationDeclarations.h 中,这些运算符在其伴随注释的元数据中具有 dispatch 设置为 True *并且* default 设置为 False。

  • 可选注册:后端扩展者可以跳过对这些运算符的注册,而不会牺牲任何支持。但是,如果后端扩展者想覆盖 PyTorch 提供的默认内核,他们仍然可以向其后端注册定制内核,Dispatcher 将仅为您的后端使用它。例如,PyTorch 当前的 max_pool2d 实现将其 indices 作为前向输出的一部分返回,这在 torch_xla 中会产生开销,因此 torch_xla 会注册自己的 max_pool2d 内核。

    • RegistrationDeclarations.h 中,这些运算符在其伴随注释的元数据中具有 dispatch 设置为 False *或* default 设置为 True。

为新后端提供 Autograd 支持#

梯度公式在数学上大多是纯粹的,因此对所有后端都是通用的。PyTorch 通常会向别名 Dispatch Key Autograd 注册一个内核,这意味着所有后端都可以使用它。

对于这些运算符,您不必担心它们的导数公式,只需为 RegistrationDeclarations.h 中的运算符编写前向定义,PyTorch 会自动处理后向传播。

Tensor my_op1(const Tensor& self, const Tensor& other) {
  // call your backend-specific APIs to implement my_op so that
  // it matches PyTorch's native behavior
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op1>, &my_op);
}

在某些情况下,PyTorch 的后向内核实现也是特定于设备的,这样它们就可以从每个后端榨取最大性能。对于这些运算符,您会在 RegistrationDeclarations.h 中看到 `op_backward` 作为*必需注册*。

Tensor my_op2_backward(const Tensor& self, const Tensor& other) {
  // call your backend-specific APIs to implement my_op2_backward so that
  // it matches PyTorch's native behavior
}

// Note backward kernel is still registered to PrivateUse1 instead of AutogradPrivateUse1.
// PyTorch will wrap your backward kernel with proper autograd setup and then link to it in
// my_op2's AutogradPrivateUse1 kernel.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<schema_my_op2>, &my_op2);
  m.impl(<schema_my_op2_backward>, &my_op2_backward);
}

在极少数情况下,PyTorch 某些运算符的梯度公式可能存在不适用于所有后端的假设。在这种情况下,后端扩展者可以选择通过向相应的 Dispatch Key(例如,如果您正在为后端使用 PrivateUse1,则为 AutogradPrivateUse1)注册一个来自 `torch::autograd::Function` 的内核来覆盖 PyTorch Autograd 层。

class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
  public:
  static Tensor forward(AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other)[0];
}

// Register the autograd kernel to AutogradPrivateUse1
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

// Register the inference kernel to PrivateUse1
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd);
}

通过这种技巧,您可以完全控制后端中 `my_add` 运算符的训练和推理行为。这是 `pytorch/xla` 存储库中的一个 示例

构建扩展#

“out-of-tree”(树外)后端通过向 PyTorch 添加 C++ 扩展来支持。一旦您准备好内核和注册,就可以通过编写使用 `setuptools` 来编译 C++ 代码的 `setup.py` 脚本来构建 C++ 扩展。这是来自 `pytorch/xla` 存储库的一个简化示例:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='torch_xla',
    ext_modules=[
        CppExtension(
            '_XLAC',
            torch_xla_sources,
            include_dirs=include_dirs,
            extra_compile_args=extra_compile_args,
            library_dirs=library_dirs,
            extra_link_args=extra_link_args + \
                [make_relative_rpath('torch_xla/lib')],
        ),
    ],
    cmdclass={
        'build_ext': Build,  # Build is a derived class of BuildExtension
    }
    # more configs...
)

有关更多详细信息,请参阅我们的 C++ 扩展教程

自定义运算符支持#

只要定制运算符由现有 PyTorch 运算符(已由您的后端支持)组成,您的新后端就应该与在 Python 中扩展的 定制运算符 无缝协作,而无需编写任何新内核。

对于 在 C++ 中扩展的自定义运算符,它们通常带有特定于后端的 C++ 内核实现(例如,torchvision 中的 nms 内核) 以及 定制的 Python API(例如 `torch.ops.torchvision.nms`) 。要支持这些运算符,后端扩展者需要为您的后端编写一个 C++ 内核,并将其正确地注册到 Dispatcher 中相应的命名空间,这与支持 PyTorch 原生运算符类似。或者,您也可以在扩展中添加定制 API,例如 `torch_xla.core.functions.nms`,以应对这些临时请求。

JIT 支持#

正如我们在 在 C++ 中注册已分派的运算符 中提到的,通过 `m.impl()` API 注册的内核支持以未装箱(unboxed)和装箱(boxed)两种方式调用。换句话说,您的定制后端也可以像 CPU 或 CUDA 等内部后端一样与我们的 JIT 跟踪/脚本前端配合使用。您甚至可以为 JIT 图上的后端编写专门的优化通道。但由于我们尚未确定 JIT 中的集成点,因此我们在此不讨论,因此当前后端支持将暂时专注于 Eager 前端。

针对原生 PyTorch 后端测试您的后端#

PyTorch 允许使用其 通用设备类型测试框架 在多种设备类型上运行测试。您可以找到有关 测试如何使用它 的详细信息,以及有关 如何添加新设备类型 的信息。添加后,使用通用设备类型测试框架的 PyTorch 测试也将使用您的设备类型运行。有关测试如何实例化的示例,请参阅此 Wiki 页面

使用您的设备类型运行 PyTorch 的现有测试套件对于确保正确性很重要,但并非所有 PyTorch 功能都受每种设备类型支持。通用设备类型测试框架允许大量自定义,以便设备类型可以选择要运行的测试、它们支持的数据类型,甚至在比较张量是否相等时使用的精度。

XLA 是一个不随 PyTorch 分发的、使用通用设备类型测试框架的示例设备类型。请参阅其对通用设备类型测试框架的 扩展,其中包含有关阻止列出测试、阻止列出数据类型以及覆盖测试精度的示例。

通用设备类型测试框架正在积极开发中。要请求功能,请在 PyTorch 的 Github 上提交问题。

向后兼容性#

目前 PyTorch 无法保证已注册运算符的向后兼容性。运算符及其模式可能会根据需要添加/修改/删除。注册的内核必须与 PyTorch 版本*完全*相同。如果 PyTorch 为运算符添加了更多参数(即使带有默认值),您的旧注册将不起作用,直到其更新为匹配 PyTorch 的新签名。

因此,我们*强烈建议*“out-of-tree”(树外)后端扩展者仅同步主要的 PyTorch 版本发布,以最大程度地减少开发中断。PyTorch 按季度发布。后端扩展者应加入 pytorch.slack.com 上的 `#announcement` 频道,以获取最新发布更新。

已知问题及补充说明#

  • 并非所有测试套件都已经是设备通用的。可通过在 PyTorch 代码库中搜索 `instantiate_device_type_tests` 来找到可扩展的测试类,例如 `TestTorchDeviceType, TestViewOps, TestTensorDeviceOps, TestTypePromotion` 等。

  • 在 C++ 中没有为在自定义后端上序列化 Python Tensor 对象提供扩展点。目前,您只能通过修改 PyTorch Tensor `__reduce_ex__` 方法 或在“out-of-tree”(树外)存储库中进行 monkey patching 来扩展它。

  • 如果您的后端不允许直接内存访问,则在支持 view 运算符时应格外注意,因为它们应该共享存储。对 view tensor 所做的更改需要传播到其基本 tensor,反之亦然。

  • 如果您的后端无法与原生 PyTorch 优化器一起工作(例如,需要在后向传播中携带状态,如 torch-xla),则在 C++ 中没有优化器的扩展点。此类用例目前只能通过在“out-of-tree”(树外)存储库中添加自定义 API 或 monkey patching 来实现。

未来工作#

使 PyTorch 中的每个组件都能无缝地扩展到“out-of-tree”(树外)后端需要对 PyTorch 内部进行大量更改。以下是我们正在积极进行的一些项目,它们可能会在未来改善用户体验:

  • 改进通用测试框架的测试覆盖率。

  • 改进 `Math` 内核覆盖率和更全面的测试,以确保 `Math` 内核的行为与其他后端(如 `CPU/CUDA`)匹配。

  • 重构 `RegistrationDeclarations.h` 以携带最少的信息,并尽可能重用 PyTorch 的代码生成。

  • 支持一个后端回退内核,以自动将输入转换为 CPU,并将结果转换回自定义后端。这将允许“完全”的运算符覆盖,即使您没有为每个运算符编写内核。

保持联系#

请使用 PyTorch 开发讨论 进行提问和讨论。如果您有任何功能请求或错误报告,请在 Github 上 提交问题

如果您有兴趣参与上述任何未来工作(例如,在 C++ 中为 PyTorch 运算符添加更多 `Math` 内核),请通过 Github 或 Slack 与我们联系!