快捷方式

OP Lowering Guide

PyTorch 封装了 C++ ATen 张量库,该库在 GPU 和 CPU 上实现了各种操作。Pytorch/XLA 是一个 PyTorch 扩展;它的目的之一是将 PyTorch 操作转换为 XLA 操作。Lowering 定义了一个将更高级别的表示转换为更低级别的表示的过程。在本文档中,我将把 PyTorch 操作转换为 XLA 操作的过程称为 lowering。XLA 编译器还将把 XlaOp 转换为 HLO,但这超出了本文档的范围。我们将把我们尚未提供 XLA lowering 的操作转发到 CPU 并调用 ATen 实现。转发到 CPU 的操作会导致显著的性能下降。我们必须降低所有在模型中使用到的操作才能获得最佳性能。

以下是您可能会从 PyTorch/XLA 调试工具中看到的尚未进行 lowering 的操作的示例

pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward,  Please open a GitHub issue with the above op lowering requests.

开始之前

您应该遵循《为 Pytorch/XLA 做贡献》中的说明来安装所需的依赖项,并从源代码构建 pytorch 和 pytorch/XLA。您不需要访问 TPU 来实现 lowering。建议在工作站上进行实验,并将其配置为使用 XLA:CPU。您可以通过运行以下命令将 Pytorch/XLA 配置为使用 XLA:CPU:

export PJRT_DEVICE=CPU

理解操作

您可以在 native_functions.yaml 中找到 C++ ATen 操作的定义。在从源代码构建 Pytorch/XLA 后,您还会在 xla/torch_xla/csrc/aten_fallback.h/cpp 中找到我们的默认实现(一个转发调用到 PyTorch 原生内核的 boxed kernel)。Pytorch 操作通常可以轻松地映射到 PyTorch 张量 API。如果不是这样,建议在 PyTorch repo 中搜索 PyTorch 的原生实现。目标是将 PyTorch 操作降低到 XLA 操作语义 中定义的 XLA 操作序列。

文件结构

以下所有提到的文件都位于 xla/torch_xla/csrc 文件夹下,除了 codegen/xla_native_functions.yaml

  1. xla_native_functions.yaml 包含所有显式进行 lowering 的算子列表(来自 Core Aten 列表)。复合算子不在此处列出。这里的每个算子名称必须直接匹配 native_functions.yaml 中列出的 pytorch 算子。此文件作为添加新 xla 算子的接口,是 PyTorch 的 codegen 机制 的输入。它生成以下 3 个文件:《XLANativeFunctions.h》、《RegisterXLA.cpp》和《RegisterAutogradXLA.cpp》。

  2. XLANativeFunctions.haten_xla_type.cpp 是 PyTorch 与 pytorch_xla 世界的入口点,并且包含为每个算子手动编写的到 XLA 的 lowering。 XLANativeFunctions.h 是通过 xla_native_functions.yaml 和 PyTorch 核心 native_functions.yaml 文件组合自动生成的,并包含需要在 aten_xla_type.cpp 中定义的内核。这里编写的内核需要使用输入的 at::Tensor 和其他参数来构建 'XLATensor'。最终的 XLATensor 需要在返回到 PyTorch 世界之前转换回 at::Tensor

  3. RegisterXLA.cppRegisterAutogradXLA.cpp 是自动生成的文件,它们将所有 lowering 注册到 PyTorch Dispatcher。它们还包括 out=inplace 操作的自动生成包装器实现。

  4. aten_fallback.h/.cpp 包含我们的 boxed fallback 实现。如果 lowering 未在 xla_native_functions.yaml + aten_xla_type.cpp 中显式定义,并且该操作不是复合的,则将使用 boxed fallback kernel。

  5. tensor_methods.h 包含 XLATensor 的声明。这些声明通常是 XLANativeFunctions.h 中声明的 at::Tensor 节点的一对一映射。

  6. tensor_methods.cpp 包含 tensor_methods.h 中定义的 XLATensor node 的实现。我们根据参数的 ir::Value 构建了相应的 ir::op,并将其包装在 XLATensor 中。Ir 代表中间表示。

  7. ops/ 目录包含所有 ir::ops 的声明和定义。较小的节点可以放在 ops/ops.h/.cpp 中。更复杂的节点可以放在单独的文件中。所有 op 都继承自 ir::ops::Node,并提供一种将输入 ir::Value lowering 到 XlaOp 序列的方法。

单元测试

我们的 CI 会对每次更改和每天的 PyTorch 原生 Python 测试进行运行。如果我们提供了 lowering,这些测试将使用 XLA 实现。我们通常不需要为 PyTorch/XLA 添加额外的 Python 测试,除非我们想验证某些 xla 行为(如动态形状)或者我们因为某些原因跳过了 pytorch 原生测试。如果需要,Python 测试应添加到 xla/test/test_operations.py。我们还需要在 xla/test/cpp/test_aten_xla_tensor.cpp 中添加 CPP 测试。此测试应调用 PyTorch c++ API,并验证我们的实现是否产生与 PyTorch 原生实现相同的结果。我们还需要通过检查 aten::opxla::op 计数器来验证张量是 XLA 张量时是否调用了 xla 实现。

技巧

Lowering 的过程是将 PyTorch 操作分解为 XlaOp 序列。为了提供 PyTorch 操作的良好 lowering,需要对 XLA 的能力有很好的掌握。阅读 XlaOp 文档并查看类似 op 的 lowering 方式是实现这一目标的最佳方法。您可以在 此 Op lowering PR 中找到一个最小的 Op lowering 示例。您也可以在 此 backward lowering PR 中找到一个稍微复杂的示例,其中包含 backward lowering。

我们在 RegisterXLA.cpp 中为某些操作提供了 out=inplace 操作的自动生成包装器实现。在这种情况下,我们只需要 lowering 纯粹的 op。一个例子是 lerp 操作,它在 native_functions.yaml 中有 6 个变体,它们是:

- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor

并将生成函数原型

at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);

如果在 xla_native_functions.yaml 中添加所有这些,则在 XLANativeFunctions.h 中。然而,如果我们只 lowering lerp.Scalarlerp.Tensor 并检查 RegisterXLA.cpp,我们会看到:

namespace {

at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
    // No device check


  // DeviceGuard omitted
  return torch_xla::lerp(self, end, weight);
}

} // anonymous namespace

at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
  auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
  at::_copy_from(wrapper_Scalar_lerp__tmp, self);
  return self;
}

...
  m.impl("lerp_.Scalar",
  TORCH_FN(wrapper_Scalar_lerp_));

代码生成会自动为 lerp_.Scalarlerp.Scalar_out 生成 lowering,使用我们的 lerp.Scalar 实现,而无需我们提供显式 lowering。

总的来说,如果 pytorch 核心中存在同时具有 out-of-place 和 out= 变体的操作,最好为 out-of-place 变体编写 lowering,因为您将免费获得代码生成的 out= lowering。

对于每个节点,我们需要传递一个 ir::OpKind。这里有一个(示例)。您可以在 interned_strings.h 中找到 OpKind 的定义。如果 aten 符号丢失,您可以提交一个类似 的 PR。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源