注意
跳转至页面底部下载完整示例代码。
将用户定义的 Triton 核函数与 torch.compile 结合使用#
创建日期:2024 年 4 月 19 日 | 最后更新:2026 年 4 月 29 日 | 最后验证:2024 年 11 月 5 日
作者: Oguz Ulgen
用户定义的 Triton 核函数可用于优化模型计算的特定部分。这些核函数使用 Triton 语言编写,旨在更轻松地达到硬件性能峰值。通过将用户定义的 Triton 核函数与 torch.compile 结合使用,您可以将这些经过优化的计算集成到 PyTorch 模型中,从而可能获得显著的性能提升。
本指南演示了如何将用户定义的 Triton 核函数与 torch.compile 结合使用。
先决条件#
在开始此秘籍之前,请确保您已具备以下条件
对
torch.compile和 Triton 的基本了解。请参阅PyTorch 2.3 或更高版本
支持 Triton 的 GPU
import torch
from torch.utils._triton import has_triton
基本用法#
在此示例中,我们将使用 Triton 文档中的一个简单向量加法核函数与 torch.compile 配合使用。参考信息请见 Triton 文档。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([-2.4273, -0.1827, 0.2075, 0.5945], device='cuda:0')
Y: tensor([-0.9062, 0.0231, -2.0174, -0.1097], device='cuda:0')
is equal to
tensor([-3.3335, -0.1596, -1.8099, 0.4848], device='cuda:0')
高级用法#
Triton 的自动调优(autotune)功能是一个强大的工具,可以自动优化 Triton 核函数的配置参数。它会探索一系列可能的配置,并选择最适合您特定用例的配置,以提供最佳性能。
当与 torch.compile 一起使用时,triton.autotune 可以帮助确保您的 PyTorch 模型尽可能高效地运行。以下是结合使用 torch.compile 和 triton.autotune 的示例。
注意
torch.compile 仅支持 triton.autotune 的 configs、keys、restore_value 和 reset_to_zero 参数。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([ 0.0443, -1.9927, 1.0136, 1.5370], device='cuda:0')
Y: tensor([-0.9129, -0.2742, 1.1992, -1.1064], device='cuda:0')
is equal to
tensor([-0.8686, -2.2669, 2.2128, 0.4306], device='cuda:0')
可组合性#
用户定义的 Triton 核函数并不会自动支持所有 PyTorch 子系统。这体现在以下使用场景中:
添加 CPU 后备方案(CPU fallback)
添加
FlopCounter公式与张量子类(Tensor Subclasses)组合
要与更多的 PyTorch 子系统组合,请使用 torch.library.triton_op。
triton_op 是一种结构化的方式,用于定义由一个或多个 Triton 核函数支持的自定义算子:就像常规自定义算子(torch.library.custom_op)一样,您可以通过 torch.library 指定与 PyTorch 子系统的交互。然而,与 torch.library.custom_op 不同(它对于 torch.compile 来说是黑盒调用),torch.compile 会追踪进入 triton_op 以应用优化。
以下是将 Triton 核函数集成到 PyTorch 时应使用的 API 对照表。
Triton 核函数(无显式 |
|
|
|
|---|---|---|---|
支持推理 |
是 |
是 |
是 |
支持训练 |
大多数情况下 |
是 |
是 |
支持 |
是 |
是 |
是 |
支持 |
大多数情况下 |
大多数情况下 |
所有情况下 |
|
是 |
是 |
否 |
支持 AOTInductor |
是 |
是 |
否 |
支持 PyTorch 子系统,如 FlopCounterMode、CPU Fallback、张量子类 |
否 |
是 |
是 |
使用 triton_op 包装 Triton 核函数#
使用 torch.library.triton_op 来包装可能调用一个或多个 Triton 核函数的函数。使用 torch.library.wrap_triton 来包装对 Triton 核函数的调用。
from torch.library import triton_op, wrap_triton
@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.sin(x)
tl.store(out_ptr + offsets, output, mask=mask)
您可以通过以下两种方式之一调用 triton_op。
x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)
assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())
生成的 triton_op 可与 torch.compile 和 AOTInductor 一起使用。
y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
添加训练支持#
使用 register_autograd 为 triton_op 添加自动微分(autograd)公式。优先使用此方法,而不是 torch.autograd.Function(后者在与 torch.compile 组合时有各种兼容性隐患)。
注意,反向传播必须是 PyTorch 可识别算子的组合。如果您希望反向传播也调用 Triton 核函数,那么这些函数也必须包装在 triton_op 中。
@triton.jit
def cos_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.cos(x)
tl.store(out_ptr + offsets, output, mask=mask)
@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * mycos(x)
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x)
mysin.register_autograd(backward, setup_context=setup_context)
添加 CPU 后备方案#
Triton 核函数无法在 CPU 上运行。使用 register_kernel 为 triton_op 添加 CPU(或任何其他设备)后备方案。
后备方案必须由 PyTorch 算子组成。
添加 FlopCounter 公式#
要指定 Triton 核函数在 PyTorch 的浮点运算计数器(flop counter)中报告的浮点运算数,请使用 register_flop_formula。
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
numel = 1
for s in x_shape:
numel *= s
return numel
x = torch.randn(3, device="cuda")
FlopCounterMode 需要 tabulate 库。在运行以下代码之前,请确保已安装 tabulate,或通过运行 pip install tabulate 进行安装。
限制#
截至 PyTorch 2.3,torch.compile 中对用户定义 Triton 核函数的支持包括动态形状、torch.autograd.Function、JIT Inductor 和 AOT Inductor。您可以结合使用这些功能来构建复杂的高性能模型。
PyTorch 2.6 增加了 torch.library.triton_op,增加了对张量子类中用户定义 Triton 核函数的支持以及其他高级功能。
然而,有一些限制需要注意:
Triton 特性: 虽然
triton.heuristics可以独立使用,也可以在triton.autotune之前使用,但不能在triton.autotune之后使用。这意味着如果要同时使用triton.heuristics和triton.autotune,必须先使用triton.heuristics。
结论#
在本指南中,我们探讨了如何利用用户定义的 Triton 核函数与 torch.compile 配合使用。我们深入介绍了简单向量加法核函数的基本用法,以及涉及 Triton 自动调优功能的高级用法。我们还讨论了用户定义 Triton 核函数与其他 PyTorch 功能的可组合性,并强调了当前的一些限制。
另请参阅#
脚本运行总时长:(0 分 3.198 秒)