注意
转到末尾 下载完整的示例代码
自动生成自定义内核的转换器¶
我们将演示如何使用 Torch-TensorRT 和 TensorRT 10.8 中基于 Python 的新插件系统自动生成自定义内核的转换器。
如果 Torch-TensorRT 不知道如何编译某些操作,它会回退到 PyTorch 的实现。然而,这会导致图中断,并会降低模型的性能。解决操作不支持的最简单方法是添加一个分解(请参阅:为 Dynamo 前端编写降低通道)——它根据 Torch-TensorRT 中支持的 PyTorch 操作来定义操作,或者添加一个转换器(请参阅:为 Dynamo 前端编写转换器)——它根据 TensorRT 操作来定义操作。
在某些情况下,这两种方法都不是很好的选择,也许是因为该操作是一个不属于标准 PyTorch 的自定义核函数,或者 TensorRT 无法原生支持它。
对于这些情况,可以使用 TensorRT 插件替换 TensorRT 引擎**内部**的操作符,从而避免图中断带来的性能和资源开销。
以前,这涉及到不仅要构建一个高性能内核,还要设置它以便在 TensorRT 中运行(参见:在 Torch-TensorRT 中使用 TensorRT 引擎内的自定义内核)的复杂过程。借助 TensorRT 10.8,出现了一个新的 Python 原生插件系统,极大地简化了这一过程。这个插件系统还允许 Torch-TensorRT 自动生成必要的转换代码,以将 PyTorch 中的操作转换为 TensorRT。
在 PyTorch 中编写自定义操作¶
之前的教程已经涵盖了在 PyTorch 中创建自定义操作符,这些操作符稍后将与 Torch-TensorRT 一起使用。
这里我们用 Triton 定义一个简单的元素级乘法操作。然后将这个操作符注册为 PyTorch 中的自定义操作符,并附带其主机启动代码以及一个“元内核”。元内核是一个描述操作符将执行的形状和数据类型转换的函数。这个元内核被 Dynamo 和 Torch-TensorRT 使用,因此有必要定义它。
from typing import Tuple
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl
@triton.jit
def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
# Program ID determines the block of data each thread will process
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals
# Store the result in Z
tl.store(Z + offsets, z_vals)
@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc]
def elementwise_mul(
X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."
# Create output tensor
Z = torch.empty_like(X)
# Define block size
BLOCK_SIZE = 1024
# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
# Launch the kernel
elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
return Z
元素级操作的元内核就是其中一个输入的形状和数据类型,因为我们在操作过程中不会改变形状。
@torch.library.register_fake("torchtrt_ex::elementwise_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
return x
使用快速部署插件系统为 TensorRT 编写插件¶
TensorRT 10.8 中的快速部署插件系统允许使用显著更少的样板代码在 Python 中创建自定义插件。它使用与 PyTorch 类似的系统,您可以在其中定义一个描述运算符将执行的形状和数据类型转换的函数,然后定义给定 GPU 内存句柄启动内核的代码。
就像 PyTorch 元内核一样,输入和输出之间没有形状或数据类型的转换,所以我们可以直接告诉 TensorRT 期望与我们得到的相同的形状
@trtp.register("torchtrt_ex::elementwise_mul")
def _(
x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int
) -> Tuple[trtp.TensorDesc]:
return x.like()
这里我们重用与 PyTorch 相似的主机启动代码,但在启动内核之前我们需要将 TensorRT 张量转换为 PyTorch 张量。这些操作也是原地操作,因此结果必须放在 TensorRT 提供的输出张量中。
@trtp.impl("torchtrt_ex::elementwise_mul")
def _(
x: trtp.Tensor,
y: trtp.Tensor,
b: float,
a: int,
outputs: Tuple[trtp.Tensor],
stream: int,
):
# Define block size
BLOCK_SIZE = 1024
# Grid of programs
grid = lambda meta: (x.numel() // meta["BLOCK_SIZE"],)
x_t = torch.as_tensor(x, device="cuda")
y_t = torch.as_tensor(y, device="cuda")
z_t = torch.as_tensor(outputs[0], device="cuda")
# Launch the kernel
elementwise_mul_kernel[grid](x_t, y_t, z_t, BLOCK_SIZE=BLOCK_SIZE)
生成转换器¶
鉴于我们已经在 PyTorch 和 TensorRT 中定义了自定义运算符,我们现在可以为该运算符生成转换器。只要命名空间和名称匹配,以下函数将自动为该运算符生成转换器。
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)
将我们的转换器与模型一起使用¶
现在我们可以在模型中使用自定义运算符,并用 Torch-TensorRT 编译它。我们可以看到自定义运算符在模型的前向传播中被用作其中一个操作。此时编译模型的过程与标准 Torch-TensorRT 用法相同。
class MyModel(torch.nn.Module): # type: ignore[misc]
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.add(x, y)
res = torch.ops.torchtrt_ex.elementwise_mul.default(x, z, a=1)
return res
my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
n = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
with torch_tensorrt.logging.errors():
model_trt = torch_tensorrt.compile(my_model, inputs=[m, n], min_block_size=1)
for i in range(300):
res = model_trt(m, n)
assert torch.allclose(res, my_model(m, n))
print("Ran with custom plugin!")
脚本总运行时间: ( 0 分 0.000 秒)