评价此页

PyTorch 自定义运算符#

创建于:2024 年 6 月 18 日 | 最后更新:2025 年 7 月 31 日 | 最后验证:2024 年 11 月 5 日

PyTorch 提供了大量的运算符库,可用于 Tensor(例如 torch.addtorch.sum 等)。但是,您可能希望将新的自定义操作引入 PyTorch,并使其与 torch.compile、autograd 和 torch.vmap 等子系统协同工作。为此,您必须通过 Python torch.library 文档 或 C++ TORCH_LIBRARY API 将自定义操作注册到 PyTorch。

从 Python 编写自定义运算符#

请参阅 自定义 Python 运算符

如果您满足以下条件,您可能希望从 Python(而不是 C++)编写自定义运算符:

  • 您有一个 Python 函数,希望 PyTorch 将其视为不透明的可调用对象,尤其是在 torch.compiletorch.export 方面。

  • 您有一些 Python 绑定到 C++/CUDA 内核,并希望这些内核与 PyTorch 子系统(如 torch.compiletorch.autograd)结合使用

  • 您正在使用 Python(而不是像 AOTInductor 这样的纯 C++ 环境)。

将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成#

请参阅 自定义 C++ 和 CUDA 运算符

注意

SYCL 作为 Intel GPU 的后端编程语言。集成自定义 Sycl 代码请参阅 自定义 SYCL 运算符

如果您满足以下条件,您可能希望从 C++(而不是 Python)编写自定义运算符:

  • 您有自定义 C++ 和/或 CUDA 代码。

  • 您计划将此代码与 AOTInductor 结合使用以进行无 Python 推理。

自定义运算符手册#

有关教程和本页未涵盖的信息,请参阅 自定义运算符手册(我们正在努力将信息迁移到我们的文档站点)。我们建议您首先阅读上述教程之一,然后将自定义运算符手册作为参考;它不适合从头到尾阅读。

我应该何时创建自定义运算符?#

如果您的操作可以表示为内置 PyTorch 运算符的组合,那么请将其编写为 Python 函数并调用它,而不是创建自定义运算符。如果您要调用 PyTorch 无法理解的库(例如,自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用运算符注册 API 创建自定义运算符。

我为什么要创建自定义运算符?#

可以通过获取 Tensor 的数据指针并将其传递给 pybind'ed 内核来使用 C/C++/CUDA 内核。但是,这种方法不适用于 PyTorch 子系统,例如 autograd、torch.compile、vmap 等。为了使操作与 PyTorch 子系统结合使用,它必须通过运算符注册 API 进行注册。