自定义 C++ 和 CUDA 算子#
创建日期:2024 年 6 月 18 日 | 最后更新:2026 年 1 月 20 日 | 最后验证:2024 年 11 月 5 日
作者: Richard Zou
如何将用 C++/CUDA 编写的自定义算子集成到 PyTorch 中
如何使用
torch.library.opcheck测试自定义算子
PyTorch 2.4 或更高版本(如果使用稳定 ABI,则为 PyTorch 2.10 或更高版本)
对 C++ 和 CUDA 编程有基本了解
注意
本教程也适用于 AMD ROCm,无需进行额外修改。
PyTorch 提供了大量用于处理张量的算子库(例如 torch.add, torch.sum 等)。但是,您可能希望将新的自定义算子引入 PyTorch。本教程演示了编写 C++/CUDA 自定义算子的推荐路径。
在本教程中,我们将演示如何编写一个融合乘加(fused multiply-add)的 C++ 和 CUDA 算子,使其能够与 PyTorch 子系统配合使用。该操作的语义如下:
def mymuladd(a: Tensor, b: Tensor, c: float):
return a * b + c
您可以在 extension-cpp 存储库中找到本教程的完整工作示例,其中包含两种并行实现:
extension_cpp/:使用标准的 ATen/LibTorch API。
extension_cpp_stable/:使用 LibTorch 稳定 ABI 支持的 API(推荐用于 PyTorch 2.10+)。
您应该使用哪个 API?
ABI 稳定版 LibTorch API(推荐):如果您使用的是 PyTorch 2.10+,我们建议使用 ABI 稳定版 API。它允许您构建单个 wheel 包,该包可在多个 PyTorch 版本(2.10、2.11、2.12 等)上运行,从而降低维护多个 PyTorch 发行版的负担。有关更多详细信息,请参阅下方的 LibTorch 稳定 ABI(与 PyTorch 版本无关) 部分。
非 ABI 稳定版 LibTorch API:如果您需要稳定 ABI 中尚不可用的 API,或者您的目标是 2.10 之前的 PyTorch 版本,请使用此 API。请注意,您需要为您想要支持的每个 PyTorch 版本构建单独的 wheel 包。
下方的代码片段通过选项卡展示了两种实现方式,默认显示 ABI 稳定版 API。
设置构建系统#
如果您正在开发自定义 C++/CUDA 代码,则必须对其进行编译。请注意,如果您正在对接一个已经具有预编译 C++/CUDA 代码绑定的 Python 库,您可以考虑编写一个自定义 Python 算子(自定义 Python 算子)。
使用 torch.utils.cpp_extension 为 PyTorch 编译自定义 C++/CUDA 代码。C++ 扩展既可以通过 setuptools “提前(ahead of time)”构建,也可以通过 load_inline “即时(just in time)”构建;我们将重点介绍“提前”构建的方式。
使用 cpp_extension 就像编写 setup.py 一样简单:
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
"extension_cpp",
["muladd.cpp"],
extra_compile_args={
"cxx": [
# define Py_LIMITED_API with min version 3.9 to expose only the stable
# limited API subset from Python.h
"-DPy_LIMITED_API=0x03090000",
# define TORCH_TARGET_VERSION with min version 2.10 to expose only the
# stable API subset from torch
"-DTORCH_TARGET_VERSION=0x020a000000000000",
]
},
py_limited_api=True)], # Build 1 wheel across multiple Python versions
cmdclass={'build_ext': cpp_extension.BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version
)
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
"extension_cpp",
["muladd.cpp"],
extra_compile_args={
"cxx": [
"-DPy_LIMITED_API=0x03090000",
]
},
py_limited_api=True)],
cmdclass={'build_ext': cpp_extension.BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}}
)
如果您需要编译 CUDA 代码(例如 .cu 文件),则改用 torch.utils.cpp_extension.CUDAExtension。请参阅 extension-cpp 以获取有关如何设置此项的示例。
CPython 无关性(CPython Agnosticism)#
上述示例代表了我们所说的 CPython 无关 wheel 包,这意味着我们正在构建一个可以在多个 CPython 版本上运行的单一 wheel 包(类似于纯 Python 包)。CPython 无关性对于最大限度地减少您的自定义库需要支持和发布的 wheel 数量是非常理想的。我们希望支持的最低版本是 3.9,因为它是目前受支持的最旧版本,因此我们在整个安装代码中使用相应的十六进制代码和说明符。我们建议在您想要支持的最低 CPython 版本的环境中构建扩展,以最大限度地减少未知行为,因此,此处我们在 CPython 3.9 环境中构建扩展。构建完成后,这个单一的 wheel 将可以在任何 CPython 3.9+ 环境中运行。要实现这一点,有三行关键代码需要注意。
第一行是在 extra_compile_args 中将 Py_LIMITED_API 指定为您想要支持的最低 CPython 版本。
extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
定义 Py_LIMITED_API 标志有助于验证扩展程序确实仅使用了 CPython 稳定有限 API,这是构建 CPython 无关 wheel 的前提条件。如果不满足此要求,则有可能构建出一个看起来与 CPython 无关但实际上会在其他 CPython 环境中崩溃,或者更糟糕的是,悄无声息地出现错误的 wheel。请务必避免使用不稳定的 CPython API,例如来自 libtorch_python 的 API(特别是 pytorch/python 绑定),并仅使用来自 libtorch 的 API(ATen 对象、算子和调度程序)。我们强烈建议定义 Py_LIMITED_API 标志,以帮助确定该扩展是否合规并可安全作为 CPython 无关 wheel 使用。请注意,定义此标志并不能完全保证构建出的 wheel 是 CPython 无关的,但这比“荒蛮西部”式的方法要好。 Python 文档中提到了几个注意事项,您应该自行测试并验证该 wheel 对于相关的 CPython 版本是否真正无关。
第二行和第三行指定了 py_limited_api,告知 setuptools 您打算构建一个 CPython 无关 wheel,这将相应地影响 wheel 的命名。
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
...,
py_limited_api=True)], # Build 1 wheel across multiple Python versions
...,
options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version
)
必须将 py_limited_api=True 作为参数传递给 CppExtension/CUDAExtension,并作为 "bdist_wheel" 命令的一个选项,并指定支持的最低 CPython 版本(在本例中为 3.9)。因此,我们教程中的 setup 将构建一个命名正确的 wheel,它可以安装在多个 >=3.9 的 CPython 版本上。
如果您的扩展使用了稳定有限集之外的 CPython API,那么您就无法构建 CPython 无关 wheel!在这种情况下,您应该为每个 CPython 版本构建一个单独的 wheel。
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
"extension_cpp",
["muladd.cpp"])],
cmdclass={'build_ext': cpp_extension.BuildExtension},
)
LibTorch 稳定 ABI(与 PyTorch 版本无关)#
除了 CPython 无关性之外,还有第二个 wheel 兼容性维度:LibTorch 无关性。虽然 CPython 无关性允许构建一个可在多个 Python 版本(3.9、3.10、3.11 等)上运行的单一 wheel,但 LibTorch 无关性允许构建一个可在多个 PyTorch 版本(2.10、2.11、2.12 等)上运行的单一 wheel。这两个概念是正交的,可以结合使用。
要实现 LibTorch 无关性,您必须使用 ABI 稳定版 LibTorch API,它提供了一个用于与 PyTorch 张量和算子交互的稳定 API。例如,您必须使用 torch::stable::Tensor 而不是 at::Tensor。有关稳定 ABI 的全面文档(包括迁移指南、支持的类型和基于栈的 API 约定),请参阅 LibTorch 稳定 ABI 文档。
稳定 ABI 的 setup.py 包含 TORCH_TARGET_VERSION=0x020a000000000000,这表明该扩展的目标是 LibTorch 稳定 ABI,且支持的最低 PyTorch 版本为 2.10。版本格式为:[主版本 1 字节][次版本 1 字节][补丁版本 1 字节][ABI 标签 5 字节],因此 2.10.0 = 0x020a000000000000。
如果稳定 API/ABI 中没有您需要的功能,您可以使用非 ABI 稳定版 LibTorch API,但您需要为您想要支持的每个 PyTorch 版本构建单独的 wheel。
定义自定义算子并添加后端实现#
首先,让我们编写一个计算 mymuladd 的 C++ 函数。
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
torch::stable::Tensor mymuladd_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
double c) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);
const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
}
return result;
}
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>
at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
}
return result;
}
为了从 PyTorch 的 Python 前端使用它,我们需要使用 TORCH_LIBRARY(或 STABLE_TORCH_LIBRARY)宏将其注册为 PyTorch 算子。这将自动把该算子绑定到 Python 中。
算子注册是一个两步过程:
定义算子 - 此步骤确保 PyTorch 识别该新算子。
注册后端实现 - 在此步骤中,将各种后端(如 CPU 和 CUDA)的实现与算子关联起来。
定义算子#
要定义一个算子,请按照以下步骤操作:
为算子选择一个命名空间。我们建议命名空间使用您的顶级项目名称;在本教程中,我们将使用“extension_cpp”。
提供一个模式字符串(schema string),指定算子的输入/输出类型,以及输入张量是否会被原地修改(mutate)。除了 Tensor 和 float,我们还支持更多类型;请参阅 自定义算子手册 以获取更多详细信息。
如果您编写的算子可以修改其输入张量,请参阅此处(创建可修改算子)了解如何指定这一点。
STABLE_TORCH_LIBRARY(extension_cpp, m) {
// Note that "float" in the schema corresponds to the C++ double type
// and the Python float type.
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
}
TORCH_LIBRARY(extension_cpp, m) {
// Note that "float" in the schema corresponds to the C++ double type
// and the Python float type.
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
}
这使得该算子可以通过 torch.ops.extension_cpp.mymuladd 从 Python 中调用。
注册算子的后端实现#
使用 TORCH_LIBRARY_IMPL(或 STABLE_TORCH_LIBRARY_IMPL)为算子注册后端实现。
请注意,我们使用 TORCH_BOX() 包装了函数指针——这是稳定 ABI 函数正确处理参数装箱/拆箱所必需的。
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
}
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
}
如果您还有 mymuladd 的 CUDA 实现,可以在单独的 TORCH_LIBRARY_IMPL(或 STABLE_TORCH_LIBRARY_IMPL)块中注册它。
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/c/shim.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}
torch::stable::Tensor mymuladd_cuda(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
double c) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);
const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();
int numel = a_contig.numel();
// For now, we rely on the raw shim API to get the current CUDA stream.
// This will be improved in a future release.
// When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
// check the error code and throw an appropriate runtime_error otherwise.
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
}
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}
at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
int numel = a_contig.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", &mymuladd_cuda);
}
为算子添加 torch.compile 支持#
要为算子添加 torch.compile 支持,我们必须添加一个 FakeTensor 内核(也称为“元内核”或“抽象实现”)。FakeTensor 是具有元数据(如形状、数据类型、设备)但没有实际数据的张量:算子的 FakeTensor 内核指定了如何根据输入张量的元数据计算输出张量的元数据。FakeTensor 内核应返回您选择的、具有正确元数据(形状/步长/dtype/设备)的虚拟张量。
我们建议通过 torch.library.register_fake API 从 Python 执行此操作,尽管也可以从 C++ 执行(有关详细信息,请参阅 自定义算子手册)。
# Important: the C++ custom operator definitions should be loaded first
# before calling ``torch.library`` APIs that add registrations for the
# C++ custom operator(s). The following import loads our
# C++ custom operator definitions.
# Note that if you are striving for Python agnosticism, you should use
# the ``load_library(...)`` API call instead. See the next section for
# more details.
from . import _C
@torch.library.register_fake("extension_cpp::mymuladd")
def _(a, b, c):
torch._check(a.shape == b.shape)
torch._check(a.dtype == torch.float)
torch._check(b.dtype == torch.float)
torch._check(a.device == b.device)
return torch.empty_like(a)
设置 Python/C++ 混合注册#
在本教程中,我们在 C++ 中定义了一个自定义算子,在 C++ 中添加了 CPU/CUDA 实现,并在 Python 中添加了 FakeTensor 内核和反向传播公式。这些注册的加载(或导入)顺序非常重要(顺序错误会导致错误)。
要将自定义算子与 Python/C++ 混合注册一起使用,我们必须先加载包含自定义算子定义的 C++ 库,然后调用 torch.library 注册 API。这可以通过以下三种方式实现:
加载包含自定义算子定义的 C++ 库的第一种方法是为 _C 定义一个虚拟 Python 模块。然后,在 Python 中,当您使用
import _C导入该模块时,与该扩展对应的.so文件将被加载,并且TORCH_LIBRARY和TORCH_LIBRARY_IMPL静态初始化程序将运行。可以使用PYBIND11_MODULE创建虚拟 Python 模块,但您会注意到这无法通过Py_LIMITED_API编译,因为pybind11并不承诺仅使用稳定的受限 CPython API!使用以下代码,很遗憾您无法为您的扩展构建 CPython 无关 wheel!(铺垫:我想知道第二种方法是什么 ;) )。
// in, say, not_agnostic/csrc/extension_BAD.cpp
#include <pybind11/pybind11.h>
PYBIND11_MODULE("_C", m) {}
# in, say, extension/__init__.py
from . import _C
在本教程中,由于我们重视构建一个适用于多个 CPython 版本的单一 wheel 的能力,我们将用稳定 API 调用替换不稳定的
PYBIND11调用。下面的代码可以使用-DPy_LIMITED_API=0x03090000进行编译,并成功为我们的_C扩展创建一个虚拟 Python 模块,以便可以从 Python 导入它。有关更多详细信息,请参阅 extension_cpp/__init__.py 和 extension_cpp/csrc/muladd.cpp。
#include <Python.h>
extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so consisting of this file
in this extension, so that the TORCH_LIBRARY static initializers
below are run. */
PyObject* PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}
# in, say, extension/__init__.py
from . import _C
如果您想在 C++ 自定义算子中完全避免使用
Python.h,可以在 Python 中使用torch.ops.load_library("/path/to/library.so")来加载从扩展编译得到的.so文件。请注意,使用这种方法,不会为该扩展创建_CPython 模块,因此您无法在 Python 中调用import _C。与依赖 import 语句触发自定义算子注册不同,torch.ops.load_library("/path/to/library.so")可以实现这一点。随之而来的挑战变成了如何定位.so文件以便加载它们,这并非总是显而易见的。
import torch
from pathlib import Path
so_files = list(Path(__file__).parent.glob("_C*.so"))
assert (
len(so_files) == 1
), f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops
为算子添加训练(自动微分)支持#
使用 torch.library.register_autograd 为算子添加训练支持。优先使用此方法,而不是直接使用 Python torch.autograd.Function(有关详细信息,请参阅 自定义算子手册)。
def _backward(ctx, grad):
a, b = ctx.saved_tensors
grad_a, grad_b = None, None
if ctx.needs_input_grad[0]:
grad_a = grad * b
if ctx.needs_input_grad[1]:
grad_b = grad * a
return grad_a, grad_b, None
def _setup_context(ctx, inputs, output):
a, b, c = inputs
saved_a, saved_b = None, None
if ctx.needs_input_grad[0]:
saved_b = b
if ctx.needs_input_grad[1]:
saved_a = a
ctx.save_for_backward(saved_a, saved_b)
# This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)
请注意,反向传播(backward)必须是由 PyTorch 可理解的算子组成的。如果您希望在反向传播中使用另一个自定义 C++ 或 CUDA 内核,则必须将其包装成一个自定义算子。
如果我们有自己的自定义 mymul 内核,我们需要将其包装成一个自定义算子,然后从反向传播中调用它。
torch::stable::Tensor mymul_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);
const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i];
}
return result;
}
STABLE_TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
m.impl("mymul", TORCH_BOX(&mymul_cpu));
}
at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i];
}
return result;
}
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
}
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
m.impl("mymul", &mymul_cpu);
}
def _backward(ctx, grad):
a, b = ctx.saved_tensors
grad_a, grad_b = None, None
if ctx.needs_input_grad[0]:
grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
if ctx.needs_input_grad[1]:
grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
return grad_a, grad_b, None
def _setup_context(ctx, inputs, output):
a, b, c = inputs
saved_a, saved_b = None, None
if ctx.needs_input_grad[0]:
saved_b = b
if ctx.needs_input_grad[1]:
saved_a = a
ctx.save_for_backward(saved_a, saved_b)
# This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)
测试算子#
使用 torch.library.opcheck 测试自定义算子是否已正确注册。请注意,此函数并不测试梯度在数学上是否正确——请计划单独编写测试,无论是手动编写还是使用 torch.autograd.gradcheck。
def sample_inputs(device, *, requires_grad=False):
def make_tensor(*size):
return torch.randn(size, device=device, requires_grad=requires_grad)
def make_nondiff_tensor(*size):
return torch.randn(size, device=device, requires_grad=False)
return [
[make_tensor(3), make_tensor(3), 1],
[make_tensor(20), make_tensor(20), 3.14],
[make_tensor(20), make_nondiff_tensor(20), -123],
[make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
]
def reference_muladd(a, b, c):
return a * b + c
samples = sample_inputs(device, requires_grad=True)
samples.extend(sample_inputs(device, requires_grad=False))
for args in samples:
# Correctness test
result = torch.ops.extension_cpp.mymuladd(*args)
expected = reference_muladd(*args)
torch.testing.assert_close(result, expected)
# Use opcheck to check for incorrect usage of operator registration APIs
torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)
创建可修改算子#
您可能希望编写一个修改其输入的自定义算子。请使用 Tensor(a!) 在模式中指定每个可修改张量;否则,将出现未定义行为。如果有多个被修改的张量,请为每个可修改张量使用不同的名称(例如 Tensor(a!), Tensor(b!), Tensor(c!))。
让我们编写一个 myadd_out(a, b, out) 算子,它将 a+b 的内容写入 out 中。
void myadd_out_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
torch::stable::Tensor& out) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = out.mutable_data_ptr<float>();
for (int64_t i = 0; i < out.numel(); i++) {
result_ptr[i] = a_ptr[i] + b_ptr[i];
}
}
在定义算子时,我们必须在模式中指定它修改了 out 张量。
STABLE_TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
m.impl("mymul", TORCH_BOX(&mymul_cpu));
m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
}
void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(b.sizes() == out.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(out.dtype() == at::kFloat);
TORCH_CHECK(out.is_contiguous());
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
for (int64_t i = 0; i < out.numel(); i++) {
result_ptr[i] = a_ptr[i] + b_ptr[i];
}
}
在定义算子时,我们必须在模式中指定它修改了 out 张量。
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
m.impl("mymul", &mymul_cpu);
m.impl("myadd_out", &myadd_out_cpu);
}
注意
不要将任何被修改的张量作为算子的输出返回,因为这会导致与 torch.compile 等 PyTorch 子系统不兼容。
结论#
在本教程中,我们介绍了将自定义 C++ 和 CUDA 算子与 PyTorch 集成的推荐方法。TORCH_LIBRARY/STABLE_TORCH_LIBRARY 和 torch.library API 是相当底层的。有关如何使用该 API 的更多信息,请参阅 自定义算子手册。