评价此页

在 C++ 中注册调度操作符#

创建于:2020 年 7 月 22 日 | 最后更新:2024 年 7 月 22 日 | 最后验证:2024 年 11 月 5 日

警告

本教程自 PyTorch 2.4 起已弃用。请参阅 PyTorch 自定义操作符,获取有关使用自定义操作符扩展 PyTorch 的最新指南。

调度器是 PyTorch 的内部组件,负责确定当你调用诸如 torch::add 等函数时实际应运行的代码。这可能并非易事,因为 PyTorch 操作需要处理许多相互“分层”的交叉关注点。以下是它处理的一些示例:

  • 根据输入张量的设备,在操作符的 CPU 和 CUDA 实现之间切换。

  • 根据是否需要自动梯度处理,在操作符的自动梯度和后端实现之间切换。

  • 在需要自动混合精度时应用自动转换。

  • 当操作符在 vmap 调用下运行时应用批处理规则。

  • 如果您正在跟踪模型以进行导出,则跟踪操作的执行。

如果您在 自定义操作符代码中发现自己手动编写 if 语句来处理这些情况,调度器 API 可以帮助组织您的代码。(相反,如果您的自定义操作符非常简单,并且仅用于 CPU 推理,您可能不需要使用调度器,只需使用基本 API 即可。)

在本教程中,我们将描述如何构建自定义操作符注册以使用调度器来组织各种组件。我们假设您熟悉如何注册操作符以及如何编写自定义自动梯度函数

定义模式和后端实现#

调度器的一般原理是,它将操作符的实现划分为多个内核,每个内核实现特定于*调度键*的功能,例如 CPU、CUDA。调度器确定在调用操作符时最高优先级的调度键是什么(这是通过查看张量参数和一些线程局部状态来完成的),并将控制权转移给该调度键的内核。最终效果是,当您调用操作符时,我们首先执行自动梯度内核,然后根据传入张量的设备类型重新调度到后端内核。

让我们看看实现这一过程的各个部分。首先,我们必须定义相关操作符的模式。与简单的 pybind11 风格的操作符注册不同,我们此时实际上不提供操作符的实现;我们只提供一个模式字符串,指定所有其他内核将遵守的操作符的类型签名

TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

接下来,我们需要实际提供一些此操作符的实现。具体来说,这是一个非常简单的 CPU 加法实现:

Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  Tensor self = self_.contiguous();
  Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}

我们希望将此函数注册为 myops::myadd 的实现。然而,简单的注册方式 (def("myadd", myadd_cpu)) 会将内核注册为在所有情况下运行,即使张量不是 CPU 张量也是如此!(在内部,我们将其称为“全包”内核,因为它包含所有情况。)为了确保 myadd_cpu 仅适用于 CPU 张量,我们可以使用 TORCH_LIBRARY_IMPL 宏:

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}

TORCH_LIBRARY_IMPL 允许我们注册特定调度键(在此例中为 CPU)的操作符实现。每次调用 impl 都会将一个 CPU 内核与相应的操作符(我们之前在 TORCH_LIBRARY 块中定义)关联起来。如果我们还有一个 CUDA 实现 myadd_cuda,我们可以在单独的 TORCH_LIBRARY_IMPL 块中注册它:

TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}

这些注册可以跨文件甚至跨库边界拆分;例如,您可以将这两个 TORCH_LIBRARY_IMPL 块编译成单独的 myops_cpumyops_cuda 动态库。一般来说,您的注册结构将如下所示:

  1. 一个单独的 TORCH_LIBRARY,在一个中心位置列出您的命名空间中的所有自定义操作符。

  2. 每个调度键一个 TORCH_LIBRARY_IMPL,用于注册该键的实现(例如,CPU 或 CUDA)。如果需要,您可以进一步将 TORCH_LIBRARY_IMPL 块细分为每个操作符一个块。如果您为每个操作符实现都有一个单独的文件,但不想在头文件中公开操作符,这会很方便;您可以直接将注册放在定义操作符的 cpp 文件中。

注意

您是否知道也可以为 PyTorch 中现有的核心运算符编写 TORCH_LIBRARY_IMPL 块?这就是 PyTorch 的 XLA 支持的实现方式:torch_xla 库包含一个 TORCH_LIBRARY_IMPL,它为 XLA 调度键上的所有基本运算符提供实现。

对于不需要自动梯度的操作符#

注意:本节仅适用于 PyTorch >= 1.10 版本。

在下一节中,我们将讨论如何为操作符添加自动梯度支持。但是对于不需要自动梯度支持的操作,应注册以下内核以提高可用性并使您的操作表现得像 PyTorch 的内置操作符。

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl(op, autogradNotImplementedFallback());
}

上述代码注册了一个 Autograd 内核,该内核在正向传播时附加一个虚拟的 NotImplemented 节点(保留输入的 require_grad 属性)。在反向传播时,NotImplemented 节点会引发错误。这对于大型模型的调试很有帮助,因为在以前,很难精确地指出在正向传播过程中 requires_grad 属性是在哪里丢失的。

原地操作或视图操作#

为确保正确性和最佳性能,如果您的操作符修改了输入,或者返回了一个与输入之一别名的张量,则应采取两个额外步骤:

  1. 除了上述 Autograd 内核之外,还要注册一个 ADInplaceOrView 内核。此内核处理必要的簿记,以确保原地或视图操作的正确性。重要的是,此 ADInplaceOrView 内核只能与 autogradNotImplementedFallback 一起使用。

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl(op, autogradNotImplementedFallback());
}
TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
  m.impl(op, autogradNotImplementedInplaceOrViewFallback());
}
  1. 上面注册的 AutogradADInplaceOrView 封装内核依赖于其逻辑中的操作符模式信息。如果您的操作符修改了输入,或者返回了与输入之一别名的张量,那么确保您的模式正确反映这一点非常重要。有关如何注释模式的更多信息,请参阅此处

添加自动梯度支持#

此时,我们有一个同时具有 CPU 和 CUDA 实现的操作符。我们如何为其添加自动梯度支持?您可能已经猜到,我们将注册一个自动梯度内核(类似于自定义自动梯度函数教程中描述的)!然而,有一个转折:与 CPU 和 CUDA 内核不同,自动梯度内核需要*重新调度*:它需要回调到调度器以获取推理内核,例如 CPU 或 CUDA 实现。

因此,在我们编写自动梯度内核之前,让我们编写一个*调度函数*,它调用调度器来查找适合您的操作符的正确内核。此函数构成了您的操作符的公共 C++ API——实际上,PyTorch C++ API 中的所有张量函数都在底层以相同的方式调用调度器。调度函数如下所示:

Tensor myadd(const Tensor& self, const Tensor& other) {
  static auto op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("myops::myadd", "")
    .typed<decltype(myadd)>();
  return op.call(self, other);
}

我们来分解一下

  • 在第一行中,我们从调度器中查找与我们要调度到的操作符对应的类型化操作符句柄。findSchemaOrThrow 接受两个参数:操作符的(命名空间限定)名称和操作符的重载名称(通常只是空字符串)。typed 将动态类型句柄转换为静态类型句柄(进行运行时测试以确保您提供了正确的 C++ 类型),以便我们可以对其进行正常的 C++ 调用。我们传入 decltype(myadd),因为调度函数的类型与注册到调度器的底层内核的类型相同。

    为了性能,此计算在静态变量中完成,因此我们只需执行(缓慢的)查找一次。如果您错误地输入了要调用的操作符名称,此查找将在您第一次调用此函数时出错。

  • 在第二行,我们只需使用传递给调度函数的所有参数来 call 操作符句柄。这实际上将调用调度器,最终控制权将转移到适合此调用的任何内核。

有了调度函数,我们现在可以编写自动梯度内核了:

class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other)[0];
}

自动梯度函数的编写与 torch::autograd::Function 的正常使用一样,不同之处在于我们不再直接在 forward() 中编写实现,而是:

  1. 使用 at::AutoNonVariableTypeMode RAII 守卫关闭自动梯度处理,然后

  2. 调用调度函数 myadd 回调到调度器。

如果没有 (1),您的调用将无限循环(并堆栈溢出),因为 myadd 会将您发送回此函数(因为最高优先级的调度键仍然是自动梯度)。有了 (1),自动梯度将从考虑的调度键集合中排除,我们将转到下一个处理程序,即 CPU 和 CUDA。

我们现在可以以注册 CPU/CUDA 函数相同的方式注册此函数:

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}

注意

在此示例中,我们将内核注册到 Autograd,它将该内核安装为所有后端的自动梯度内核。您还可以通过使用相应的后端特定调度键(例如,AutogradCPUAutogradCUDA)来注册特定后端的优化内核。要更详细地探索这些和其他调度键选项,请查看 torch/_python_dispatcher.py 中提供的 PythonDispatcher 工具。

超越自动梯度#

在某种意义上,调度器并没有做太多事情:它所做的只是实现一个被美化了的 if 语句,类似于这样:

class MyAddFunction : ... {
public:
  static Tensor forward(
    AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {

    if (self.device().type() == DeviceType::CPU) {
      return add_cpu(self, other);
    } else if (self.device().type() == DeviceType::CUDA) {
      return add_cuda(self, other);
    } else {
      TORCH_CHECK(0, "Unsupported device ", self.device().type());
    }
  }
  ...
}

那为什么要使用调度器呢?有几个原因:

  1. 它是去中心化的。您可以组装操作符的所有部分(CPU、CUDA、自动梯度),而无需编写单个、集中的 if 语句来引用它们所有。重要的是,第三方可以为其他方面注册额外的实现,而无需修补操作符的原始定义。我们将在为新后端扩展调度器中讨论更多关于扩展调度器的问题。

  2. 它支持的调度键比 CPU、CUDA 和自动梯度更多。您可以在 c10/core/DispatchKey.h 中查看 PyTorch 中当前实现的所有调度键的完整列表。这些调度键实现了操作符的各种可选功能,如果您决定希望您的自定义操作符支持此功能,您只需为相应的键注册一个内核。

  3. 调度器实现了对盒装回退函数的支持,这些函数可以实现一次并应用于系统中的所有操作符。盒装回退可用于为调度键提供默认行为;如果您使用调度器来实现您的操作符,您也将选择所有这些操作的回退。

以下是一些您可能需要为其定义操作符的特定调度键。

自动混合精度(Autocast)#

自动转换调度键实现了对自动混合精度 (AMP) 的支持。自动转换包装器内核通常在运行操作之前,将输入的 float16float32 CUDA 张量转换为某种首选精度。例如,浮点 CUDA 张量上的矩阵乘法和卷积通常在 float16 中运行得更快,并使用更少的内存,而不会影响收敛。自动转换包装器仅在启用自动转换的上下文中有效。

这是一个假想的自定义矩阵乘法的自动转换包装器,以及它的注册:

// Autocast-specific helper functions
#include <ATen/autocast_mode.h>

Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                  at::autocast::cached_cast(at::kHalf, other));
}

TORCH_LIBRARY_IMPL(myops, Autocast, m) {
  m.impl("mymatmul", mymatmul_autocast);
}

cached_cast(kHalf, tensor)tensor 转换为 float16,如果 tensor 是 CUDA 且 float32,否则它保持 tensor 不变(参见原生自动转换操作的资格策略)。这确保了如果网络在 float16float32 CUDA 张量的任意组合上调用 mymatmulmymatmul 会以 float16 运行。同时,对非 CUDA、整数类型或 float64 输入的 mymatmul 调用不受影响。建议在您自己的自动转换包装器中使用 cached_cast 来遵循原生资格策略,但这不是必需的。例如,如果您想强制所有输入类型都执行 float16,您可以 return mymatmul(self.half(), other.half()); 而不是使用 cached_cast

请注意,与我们的自动梯度内核一样,我们在重新调度之前将 Autocast 键从调度中排除。

默认情况下,如果未提供自动转换包装器,我们将直接回退到常规操作符实现(不进行自动转换)。(我们没有将 myadd 用于此示例,因为逐点加法不需要自动转换,只需回退即可。)

何时注册自动转换包装器?不幸的是,对于操作符的首选精度没有明确的规则。您可以通过查看转换列表来了解某些原生操作符的首选精度。一般指导原则:

  • 执行归约的操作符可能应该以 float32 执行,

  • 任何底层执行卷积或通用矩阵乘法的操作符可能都应该以 float16 执行,并且

  • 其他具有多个浮点张量输入的操作符应将其标准化为通用精度(除非实现支持不同精度的输入)。

如果您的自定义操作符属于第三类,promote_type 模板有助于找出输入张量中存在的最大浮点类型,这是执行类型的最安全选择。

#include <ATen/autocast_mode.h>

Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  // The required at::kHalf argument is an optimistic initial guess.
  auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
  return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                              at::autocast::cached_cast(exec_type, t1));
}

如果您的自定义操作符是自动梯度启用的,则您只需为注册自动梯度包装器的同名函数编写并注册一个自动转换包装器。例如,如果您想为自动梯度部分中所示的 myadd 函数提供自动转换包装器,您所需要的只是:

Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return myadd(at::autocast::cached_cast(<desired dtype>, self),
               at::autocast::cached_cast(<desired dtype>, other));
}

TORCH_LIBRARY_IMPL(myops, Autocast, m) {
  m.impl("myadd", myadd_autocast);
}

没有单独的技巧来使反向方法与自动转换兼容。然而,您自定义自动梯度函数中定义的反向方法将以与自动转换设置为正向方法相同的 dtype 运行,因此您应该选择一个适合您的正向和反向方法的 <desired dtype>

批量化(Batched)#

批量张量允许您以逐个样本的方式编写代码,然后在 vmap 调用下运行时自动进行批量处理。编写批量处理规则的 API 目前正在开发中,但一旦稳定下来,您可以通过在 Batched 调度键注册内核来为您的操作符添加 vmap 支持。

跟踪器(Tracer)#

Tracer 调度键实现了在运行 torch.jit.trace 时将操作符调用记录到跟踪中的支持。我们打算提供一个盒式回退,它将实现任意操作的跟踪,请参阅 issue #41478 以跟踪进度。