评价此页

PyTorch 自定义算子#

创建日期:2024年6月18日 | 最后更新:2025年7月31日 | 最后验证:2024年11月5日

PyTorch 提供了大量的算子库,可用于处理张量(例如 torch.add, torch.sum 等)。然而,您可能希望将一个新的自定义算子引入 PyTorch,并使其能够与 torch.compile、autograd 和 torch.vmap 等子系统协同工作。为此,您必须通过 Python torch.library 文档 或 C++ TORCH_LIBRARY API 将自定义算子注册到 PyTorch 中。

使用 Python 编写自定义算子#

请参阅 自定义 Python 算子

在以下情况下,您可能希望使用 Python(而非 C++)来编写自定义算子:

  • 您有一个 Python 函数,希望 PyTorch 将其视为不透明的可调用对象(opaque callable),特别是针对 torch.compiletorch.export 的情况。

  • 您拥有 C++/CUDA 内核的 Python 绑定,并希望这些绑定能与 PyTorch 子系统(如 torch.compiletorch.autograd)相结合。

  • 您正在使用 Python(而非像 AOTInductor 这样仅限 C++ 的环境)。

将自定义 C++ 和/或 CUDA 代码集成到 PyTorch#

请参阅 自定义 C++ 和 CUDA 算子

注意

SYCL 是 Intel GPU 的后端编程语言。集成自定义 Sycl 代码请参考 自定义 SYCL 算子

在以下情况下,您可能希望使用 C++(而非 Python)来编写自定义算子:

  • 您拥有自定义的 C++ 和/或 CUDA 代码。

  • 您计划将此代码与 AOTInductor 配合使用,以实现无 Python 环境的推理。

自定义算子手册#

对于教程和本页未涵盖的信息,请参阅 自定义算子手册(我们正致力于将这些信息迁移到我们的文档网站)。我们建议您先阅读上述其中一篇教程,然后将自定义算子手册作为参考资料使用;该手册并非旨在从头到尾完整阅读。

我应该在何时创建自定义算子?#

如果您的操作可以通过内置 PyTorch 算子的组合来表达,请将其编写为 Python 函数并直接调用,而不要创建自定义算子。如果您正在调用 PyTorch 无法理解的库(例如自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用算子注册 API 来创建自定义算子。

我为什么要创建自定义算子?#

虽然可以通过获取张量的数据指针并将其传递给 pybind 绑定的内核来使用 C/C++/CUDA 内核,但这种方法无法与 autograd, torch.compile, vmap 等 PyTorch 子系统协同工作。为了使一个操作能够与 PyTorch 子系统相结合,必须通过算子注册 API 进行注册。