• 文档 >
  • 为自定义内核自动生成插件
快捷方式

为自定义内核自动生成插件

我们将演示如何使用 Torch-TensorRT,利用 TensorRT 10.7 中基于 Python 的新插件系统,为自定义内核自动生成插件。

在 Torch-TensorRT 不知道如何将其编译到 TensorRT 的情况下,它支持回退到 PyTorch 的操作实现。然而,这会带来图中断的代价,并会降低模型的性能。解决操作支持不足的最简单方法是添加一个分解(请参阅:为 Dynamo 前端编写降级通道)——它用 Torch-TensorRT 支持的 PyTorch 操作来定义该操作,或者添加一个转换器(请参阅:为 Dynamo 前端编写转换器)——它用 TensorRT 操作来定义该操作。

在某些情况下,这两种方法都不是很好的选择,也许是因为该操作是一个不属于标准 PyTorch 的自定义核函数,或者 TensorRT 无法原生支持它。

对于这些情况,可以使用 TensorRT 插件来替换 TensorRT 引擎**内部**的操作,从而避免图中断带来的性能和资源开销。

以前,这不仅涉及构建一个高性能内核的复杂过程,还包括将其设置为在 TensorRT 中运行(请参阅:在 TensorRT 引擎中使用带有 Torch-TensorRT 的自定义内核)。从 TensorRT 10.7 开始,有了一个新的 Python 原生插件系统,极大地简化了这一过程。该插件系统还允许 Torch-TensorRT 自动生成必要的转换代码,将 PyTorch 中的操作转换为 TensorRT。

在 PyTorch 中编写自定义操作

之前的教程已经涵盖了在 PyTorch 中创建自定义操作,这些操作随后会与 Torch-TensorRT 一起使用。

在这里,我们用 Triton 定义了一个简单的逐元素乘法运算符。然后,该运算符在 PyTorch 中注册为自定义操作,包括其主机启动代码和一个“元内核”。元内核是一个描述该操作将执行的形状和数据类型转换的函数。Dynamo 和 Torch-TensorRT 会使用这个元内核,因此定义它是必要的。

from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Compute the range of elements that this thread block will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this thread will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load elements from the X and Y tensors
    x_vals = tl.load(X + offsets)
    y_vals = tl.load(Y + offsets)
    # Perform the element-wise multiplication
    z_vals = x_vals * y_vals * a + b
    # Store the result in Z
    tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

    # Launch the kernel with parameters a and b
    elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

    return Z

逐元素操作的元内核只是其中一个输入的形状和数据类型,因为在操作过程中我们不会改变形状。

@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    return x

在这里,我们使用 Torch-TensorRT 中的自动插件创建功能,该功能通过 TensorRT QDP API 实现插件注册。

torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::elementwise_scale_mul"
)


# # %%
# # Generating the Converter
# # -------------------------------------------------------------------
# # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# # As long as the namespace and names match, the following function will automatically generate the converter for the operation.
# # If plugins require an output allocator to dynamically allocate output buffers, like data dependent operators, please set requires_output_allocator to True.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_scale_mul",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)


# # %%
# # Above two commands can be replaced with the following single one line:
# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True, requires_output_allocator=False)

在模型中使用我们的转换器

现在我们可以在模型中使用我们的自定义运算符,并用 Torch-TensorRT 进行编译。我们可以看到,自定义运算符被用作模型前向传播中的一个操作。此时编译模型的过程与标准的 Torch-TensorRT 用法完全相同。

class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.add(x, y)
        res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5)

        return res


my_model = MyModel().to("cuda")
m = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)
n = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)

with torch_tensorrt.logging.errors():
    model_trt = torch_tensorrt.compile(my_model, inputs=[m, n], min_block_size=1)
    for i in range(300):
        res = model_trt(m, n)
        assert torch.allclose(res, my_model(m, n))

print("Ran with custom plugin!")

脚本总运行时间: ( 0 分 0.000 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源