评价此页

自定义 SYCL 运算符#

您将学到什么
  • 如何将 SYCL 编写的自定义运算符集成到 PyTorch 中

先决条件
  • PyTorch 2.8 或更高版本

  • SYCL 编程基础知识

注意

SYCL 是英特尔 GPU(设备标签 xpu)的后端编程语言。有关配置详情,请参阅:在英特尔 GPU 上入门。英特尔编译器随英特尔深度学习精粹一起提供,负责 SYCL 编译。请确保在执行本教程中的代码示例之前安装并激活编译器环境。

PyTorch 提供了大量可在张量上运行的运算符(例如 torch.add、torch.sum 等)。然而,您可能希望为 PyTorch 添加新的自定义运算符。本教程演示了编写 SYCL 编写的自定义运算符的最佳途径。C++ 和 CUDA 运算符的教程可在 自定义 C++ 和 CUDA 运算符 中找到。

遵循此结构来创建自定义 SYCL 运算符

sycl_example/
├── setup.py
├── sycl_extension
│   ├── __init__.py
│   ├── muladd.sycl
│   └── ops.py
└── test_sycl_extension.py

设置构建系统#

如果您需要编译 **SYCL** 代码(例如,.sycl 文件),请使用 torch.utils.cpp_extension.SyclExtension。设置过程与 C++/CUDA 非常相似,只是编译参数需要针对 SYCL 进行调整。

使用 sycl_extension 就像编写以下 setup.py 一样简单

import os
import torch
import glob
from setuptools import find_packages, setup
from torch.utils.cpp_extension import SyclExtension, BuildExtension

library_name = "sycl_extension"
py_limited_api = True
extra_compile_args = {
    "cxx": ["-O3",
            "-fdiagnostics-color=always",
            "-DPy_LIMITED_API=0x03090000"],
    "sycl": ["-O3" ]
}

assert(torch.xpu.is_available()), "XPU is not available, please check your environment"
# Source files collection
this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, library_name)
sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
# Construct extension
ext_modules = [
    SyclExtension(
        f"{library_name}._C",
        sources,
        extra_compile_args=extra_compile_args,
        py_limited_api=py_limited_api,
    )
]
setup(
    name=library_name,
    packages=find_packages(),
    ext_modules=ext_modules,
    install_requires=["torch"],
    description="Simple Example of PyTorch Sycl extensions",
    cmdclass={"build_ext": BuildExtension},
    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)

定义自定义运算符并添加后端实现#

首先,让我们编写一个计算 mymuladd 的 SYCL 函数

为了从 PyTorch 的 Python 前端使用它,我们需要使用 TORCH_LIBRARY API 将其注册为 PyTorch 运算符。这将自动将运算符绑定到 Python。

如果您也有 myaddmul 的 SYCL 实现,您也可以在单独的 TORCH_LIBRARY_IMPL 块中注册它

#include <c10/xpu/XPUStream.h>
#include <sycl/sycl.hpp>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

namespace sycl_extension {
// MulAdd Kernel: result = a * b + c
static void muladd_kernel(
    int numel, const float* a, const float* b, float c, float* result,
    const sycl::nd_item<1>& item) {
    int idx = item.get_global_id(0);
    if (idx < numel) {
        result[idx] = a[idx] * b[idx] + c;
    }
}

class MulAddKernelFunctor {
public:
    MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
        : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
    void operator()(const sycl::nd_item<1>& item) const {
        muladd_kernel(numel, a, b, c, result, item);
    }

private:
    int numel;
    const float* a;
    const float* b;
    float c;
    float* result;
};

at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
    TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
    TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
    TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
    TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
    TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

    at::Tensor a_contig = a.contiguous();
    at::Tensor b_contig = b.contiguous();
    at::Tensor result = at::empty_like(a_contig);

    const float* a_ptr = a_contig.data_ptr<float>();
    const float* b_ptr = b_contig.data_ptr<float>();
    float* res_ptr = result.data_ptr<float>();
    int numel = a_contig.numel();

    sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
    constexpr int threads = 256;
    int blocks = (numel + threads - 1) / threads;

    queue.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<MulAddKernelFunctor>(
            sycl::nd_range<1>(blocks * threads, threads),
            MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
        );
    });

    return result;
}
// Defines the operators
TORCH_LIBRARY(sycl_extension, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
}

// ==================================================
// Register SYCL Implementations to Torch Library
// ==================================================
TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
    m.impl("mymuladd", &mymuladd_xpu);
}

} // namespace sycl_extension

创建 Python 接口#

sycl_extension/ops.py 文件中为我们的运算符创建一个 Python 接口

import torch
from torch import Tensor
__all__ = ["mymuladd"]

def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
    """Performs a * b + c in an efficient fused kernel"""
    return torch.ops.sycl_extension.mymuladd.default(a, b, c)

初始化包#

创建 sycl_extension/__init__.py 文件以使该包可导入

import ctypes
from pathlib import Path

import torch

current_dir = Path(__file__).parent.parent
build_dir = current_dir / "build"
so_files = list(build_dir.glob("**/*.so"))

assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"

with torch._ops.dl_open_guard():
    loaded_lib = ctypes.CDLL(so_files[0])

from . import ops

__all__ = [
    "loaded_lib",
    "ops",
]

测试 SYCL 扩展运算符#

使用简单的测试来验证运算符是否正常工作。

import torch
from torch.testing._internal.common_utils import TestCase
import unittest
import sycl_extension

def reference_muladd(a, b, c):
    return a * b + c

class TestMyMulAdd(TestCase):
    def sample_inputs(self, device, *, requires_grad=False):
        def make_tensor(*size):
            return torch.randn(size, device=device, requires_grad=requires_grad)

        def make_nondiff_tensor(*size):
            return torch.randn(size, device=device, requires_grad=False)

        return [
            [make_tensor(3), make_tensor(3), 1],
            [make_tensor(20), make_tensor(20), 3.14],
            [make_tensor(20), make_nondiff_tensor(20), -123],
            [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
        ]

    def _test_correctness(self, device):
        samples = self.sample_inputs(device)
        for args in samples:
            result = sycl_extension.ops.mymuladd(*args)
            expected = reference_muladd(*args)
            torch.testing.assert_close(result, expected)

    @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU")
    def test_correctness_xpu(self):
        self._test_correctness("xpu")

if __name__ == "__main__":
    unittest.main()

此测试通过将其输出与参考实现进行比较来检查自定义运算符的正确性。

结论#

在本教程中,我们演示了如何为 PyTorch 实现和编译自定义 SYCL 运算符。我们特别展示了一个推理操作 muladd。有关添加向后支持或启用 torch.compile 兼容性,请参阅 自定义 C++ 和 CUDA 运算符