Op 降级指南¶
PyTorch 封装了 C++ ATen 张量库,该库在 GPU 和 CPU 上实现了广泛的操作。Pytorch/XLA 是一个 PyTorch 扩展;其目的之一是将 PyTorch 操作转换为 XLA 操作。降级定义了一个将高级表示转换为低级表示的过程。在本文档中,我将把将 PyTorch 操作转换为 XLA 操作的过程称为降级。XLA 编译器也会将 XlaOp 降级为 HLO,但这超出了本文档的范围。我们将把尚未提供 XLA 降级的所有操作转发到 CPU 并调用 ATen 实现。转发到 CPU 的操作会导致显著的性能下降。我们必须将模型中使用的所有操作降级,才能获得最佳性能。
这是您可能会从 PyTorch/XLA 调试工具中看到的关于尚未降级操作的示例
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
开始之前¶
您应该遵循 为 Pytorch/XLA 做贡献 中的说明来安装必需的依赖项,并从源代码构建 pytorch 和 pytorch/XLA。您不需要 TPU 访问权限来实现降级。建议在工作站上进行实验,并将其配置为使用 XLA:CPU。您可以通过运行以下命令来配置 Pytorch/XLA 以使用 XLA:CPU:
export PJRT_DEVICE=CPU
理解操作¶
您可以在 native_functions.yaml 中找到 C++ ATen 操作的定义。在从源代码构建 Pytorch/XLA 后,您还将在 xla/torch_xla/csrc/aten_fallback.h/cpp
中找到我们的默认实现(一个包装的内核,它将调用转发到 PyTorch 本机内核)。Pytorch 操作通常可以轻松映射到 PyTorch 张量 API。如果不是这样,建议在 PyTorch repo 中搜索 PyTorch 本机实现。目标是将 PyTorch 操作降级为 XLA 操作语义 中定义的 XLA 操作序列。
文件结构¶
以下提到的所有文件都位于 xla/torch_xla/csrc
文件夹下,但 codegen/xla_native_functions.yaml
除外。
xla_native_functions.yaml
包含所有显式降级运算符的列表(来自 Core Aten 列表)。复合运算符不在此处列出。此处每个运算符名称必须直接匹配 native_functions.yaml 中列出的 pytorch 运算符。此文件作为添加新 xla 运算符的接口,并且是 PyTorch 代码生成机器 的输入。它生成以下 3 个文件:XLANativeFunctions.h
、RegisterXLA.cpp
和RegisterAutogradXLA.cpp
。XLANativeFunctions.h
和aten_xla_type.cpp
是 PyTorch 进入 pytorch_xla 世界的入口点,其中包含为每个运算符手动编写的到 XLA 的降级。XLANativeFunctions.h
是通过xla_native_functions.yaml
和 PyTorch 核心native_functions.yaml
文件组合自动生成的,并包含需要定义在aten_xla_type.cpp
中的内核声明。此处编写的内核需要使用输入的at::Tensor
和其他参数来构建 'XLATensor'。生成的XLATensor
在返回到 PyTorch 世界之前需要转换回at::Tensor
。RegisterXLA.cpp
和RegisterAutogradXLA.cpp
是自动生成的文件,它们将所有降级注册到 PyTorch Dispatcher。它们还包括out=
和inplace
运算符的自动生成包装实现。aten_fallback.h/.cpp
包含我们的包装式回退实现。如果降级未在xla_native_functions.yaml
+aten_xla_type.cpp
中显式定义,并且该运算符不是复合运算符,则将使用包装式回退内核。tensor_methods.h
包含XLATensor
声明。这些声明通常是一对一地映射我们在XLANativeFunctions.h
中声明的at::Tensor
节点。tensor_methods.cpp
包含tensor_methods.h
中定义的XLATensor 节点
的实现。我们从参数的ir::Value
构建了相应的ir::op
,并将其包装在XLATensor
中。Ir 代表中间表示。ops/
目录包含所有ir::ops
声明和定义。较小的节点可以放在ops/ops.h/.cpp
中。更复杂的节点可以放在单独的文件中。所有 ops 都继承自ir::ops::Node
,并提供了一种将输入ir::Value
降级为XlaOp
序列的方法。
单元测试¶
我们的 CI 会在每次更改时以及每天运行 PyTorch 本机 Python 测试。如果提供了降级,这些测试将使用 XLA 实现。我们通常不需要为 PyTorch/XLA 添加额外的 Python 测试,除非我们想验证某些 xla 行为(如动态形状)或因某种原因跳过了 pytorch 本机测试。如果需要,应将 Python 测试添加到 xla/test/test_operations.py
。我们还需要在 xla/test/cpp/test_aten_xla_tensor.cpp
中添加 CPP 测试。此测试应调用 PyTorch c++ API,并验证我们的实现是否产生与 PyTorch 本机实现相同的结果。我们还需要通过检查 aten::op
和 xla::op
计数器来验证在张量是 XLA 张量时是否调用了 xla 实现。
技巧¶
降级过程是将 PyTorch 操作分解为 XlaOp 序列。要提供 PyTorch 操作的良好降级,需要对 XLA 的能力有很好的掌握。阅读 XlaOp 文档并查看类似 op 的降级方式是实现这一点的最佳方法。您可以在 此 Op 降级 PR 中找到一个最小的 Op 降级示例。您也可以在 此反向降级 PR 中找到一个稍复杂的反向降级示例。
我们在 RegisterXLA.cpp
中为某些运算符自动生成了 out=
和 inplace
运算符的包装实现。在这种情况下,我们只需要降级常规运算符。一个例子是 lerp
运算符,它在 native_functions.yaml
中有 6 个变体,它们是:
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
并将生成函数原型
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);
如果在 xla_native_functions.yaml
中添加所有这些,则会在 XLANativeFunctions.h
中生成。但是,如果我们只降级 lerp.Scalar
和 lerp.Tensor
并检查 RegisterXLA.cpp
,我们将看到:
namespace {
at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
// No device check
// DeviceGuard omitted
return torch_xla::lerp(self, end, weight);
}
} // anonymous namespace
at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
at::_copy_from(wrapper_Scalar_lerp__tmp, self);
return self;
}
...
m.impl("lerp_.Scalar",
TORCH_FN(wrapper_Scalar_lerp_));
代码生成器将自动为 lerp_.Scalar
和 lerp.Scalar_out
生成降级,这些降级使用我们的 lerp.Scalar
实现,而无需我们提供显式降级。
总的来说,如果 PyTorch 核心中存在一个同时具有异地变体和 out= 变体的运算符,最好为异地变体编写降级,因为您将免费获得一个代码生成的 out= 降级。
对于每个节点,我们需要传递一个 ir::OpKind
。这里是(示例)。您可以在 interned_strings.h 中找到 OpKind
定义。如果 aten 符号丢失,您可以提交一个类似 这个 的 PR。