Custom GPU Kernels via Triton

PyTorch/XLA now supports Triton kernels, enabling high-performance deep learning model execution on GPU. Triton is a specialized language and compiler for GPU programming that empowers developers to write custom kernels which tap into the full potential of GPUs for the various operations in deep learning models.

Given a Triton kernel defined as follows:

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
  # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
  pid = tl.program_id(axis=0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)

we can run it as part of the PyTorch/XLA execution graph as follows:

import torch

import torch_xla.experimental.triton as xla_triton
import torch_xla

import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the triton kernel a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, payload from triton_call,
# the output shapes and output dtypes. The payload already contains information
# regarding how the GPU buffers will be loaded when this node is executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                                [output.shape], [torch.int64])
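
To sanity-check the kernel, its result can be compared against the same addition expressed as a regular traced XLA op. Below is a minimal sketch continuing from the example above; it assumes the custom call returns a list containing one tensor per requested output shape, so the kernel result is its first element:

# Reference result computed as an ordinary XLA op.
expected = x + y

# Assumption: _xla_gpu_custom_call returns one tensor per entry in the
# output-shape list, so the kernel result is the first (and only) element.
result = output[0]

# Moving both tensors to CPU forces execution of the pending graphs.
assert torch.equal(result.cpu(), expected.cpu())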

For a more complex kernel, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.
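
More involved kernels follow the same recipe; the main new wrinkle for kernels that write several outputs is passing one shape and one dtype per output to the custom call. Below is a hypothetical sketch continuing from the add example above; add_sub_kernel is made up for illustration, and the one-buffer-per-output-shape convention is an assumption extrapolated from the single-output example, not a documented API guarantee:

@triton.jit
def add_sub_kernel(x_ptr, y_ptr, sum_ptr, diff_ptr, n_elements,
                   BLOCK_SIZE: tl.constexpr):
  # Hypothetical kernel writing two outputs: elementwise sum and difference.
  pid = tl.program_id(axis=0)
  offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  tl.store(sum_ptr + offsets, x + y, mask=mask)
  tl.store(diff_ptr + offsets, x - y, mask=mask)

sum_out = torch.empty_like(x)
diff_out = torch.empty_like(x)
payload = xla_triton.triton_call(
    x, y, sum_out, diff_out, size, kernel=add_sub_kernel, grid=grid,
    BLOCK_SIZE=block_size)

# One output shape and one dtype per kernel output (assumed convention).
outputs = torch_xla._XLAC._xla_gpu_custom_call(
    [x, y], payload, [sum_out.shape, diff_out.shape],
    [torch.int64, torch.int64])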

Dependencies

The Triton integration depends on the triton package to function. This code is tested with triton==2.3.0. Install it with:

pip install --no-deps triton==2.3.0
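
To confirm at runtime that the pinned version is the one being picked up, a trivial check (not specific to PyTorch/XLA):

import triton
print(triton.__version__)  # expected to print 2.3.0 given the pin above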
