• 文档 >
  • 通过 Pallas 实现自定义内核
快捷方式

通过 Pallas 进行自定义内核

随着 OpenAI Triton 的兴起,自定义内核在 GPU 社区中越来越受欢迎,例如 FlashAttentionPagedAttention 的引入。为了在 TPU 世界中实现功能对等,Google 推出了 Pallas。为了让 PyTorch/XLA 继续推动 TPU 的性能,我们必须支持自定义内核,而最佳方式就是通过 Pallas。

假设您有一个 Pallas 内核定义如下

from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)

需要注意的是,在导入任何 jax 模块之前,非常重要的一点是运行 jax_import_guard()。否则,程序将在 TPU 上挂起,因为 jax 会锁定 TPU,而 torch-xla 无法访问它。

采用上述内核以兼容 PyTorch/XLA

使用示例

q = torch.randn(3, 2, 128, 4).to('xla')
k = torch.randn(3, 2, 128, 4).to('xla')
v = torch.randn(3, 2, 128, 4).to('xla')

# Adopts any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)

对于简单的内核,采用非常简单,只需一行代码。对于更复杂的内核,您可以参考我们的 Flash Attention 实现以获取详细信息。

使用内置内核

除了手动包装外部 Pallas 内核外,还有一些内置内核,PyTorch/XLA 已经完成了它们的采用。这些内置内核可以像其他 torch.ops 一样使用。当前支持的内置内核有:- FlashAttention - PagedAttention

FlashAttention

用法示例

# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = flash_attention(q, k, v)

集成示例

我们在训练测试脚本中有一个 FlashAttention 集成示例

PagedAttention

示例用法

# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = torch.ops.xla.paged_attention(
    query.squeeze(dim=1),
    key_cache,
    value_cache,
    context_lens,
    block_tables,
    pages_per_compute_block,
    megacore_mode=None,
)

集成示例

vLLM TPU 集成利用 PagedAttention 进行有效的 KV 缓存内存管理。

依赖项

Pallas 集成依赖 JAX 来运行。但是,并非所有 JAX 版本都与您安装的 PyTorch/XLA 兼容。要安装正确的 JAX

pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

编写自己的 Pallas 内核

您可以在 https://jax.net.cn/en/latest/pallas/index.html 上找到有关如何编写 Pallas 内核的权威指南。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源