• 文档 >
  • 使用 Torch-TensorRT 处理动态形状
快捷方式

使用 Torch-TensorRT 实现动态形状

默认情况下,您可以使用不同的输入形状运行 PyTorch 模型,输出形状是即时确定的。然而,Torch-TensorRT 是一个 AOT(Ahead-of-Time,预先)编译器,它需要一些关于输入形状的先验信息来编译和优化模型。

使用 torch.export (AOT) 实现动态形状

在动态输入形状的情况下,我们必须提供 (min_shape, opt_shape, max_shape) 参数,以便模型可以针对这个输入形状范围进行优化。静态和动态形状的用法示例如下。

注意:以下代码使用 Dynamo 前端。如果使用 Torchscript 前端,请将 ir=dynamo 替换为 ir=ts,其行为完全相同。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
# Compile with static shapes
inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
# or compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                              opt_shape=[4, 3, 224, 224],
                              max_shape=[8, 3, 224, 224],
                              dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)

底层原理

当我们使用 torch_tensorrt.compile API 并设置 ir=dynamo(默认值)时,编译过程分为两个阶段。

  • torch_tensorrt.dynamo.trace(使用 torch.export 和给定的输入来追踪计算图)

我们使用 torch.export.export() API 来追踪和导出一个 PyTorch 模块为 torch.export.ExportedProgram。对于动态形状的输入,通过 torch_tensorrt.Input API 提供的 (min_shape, opt_shape, max_shape) 范围用于构建 torch.export.Dim 对象,该对象用于 export API 的 dynamic_shapes 参数。请查看 _tracer.py 文件以了解其底层工作原理。

  • torch_tensorrt.dynamo.compile(使用 TensorRT 编译一个 torch.export.ExportedProgram 对象)

在转换为 TensorRT 的过程中,计算图的节点元数据中已经包含了动态形状信息,这些信息将在引擎构建阶段使用。

自定义动态形状约束

给定一个输入 x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype),Torch-TensorRT 会在 torch.export 追踪期间,通过根据提供的动态维度构建 torch.export.Dim 对象,来尝试自动设置约束。有时,我们可能需要设置额外的约束,如果不指定,TorchDynamo 就会报错。如果您需要为模型设置任何自定义约束(通过使用 torch.export.Dim),我们建议您在用 Torch-TensorRT 编译之前先导出您的程序。请参考此文档来导出具有动态形状的 PyTorch 模块。这里有一个简单的示例,导出一个对动态维度有某些限制的 matmul 层。

import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
trt_gm(*inputs)

使用 torch.compile (JIT) 实现动态形状

torch_tensorrt.compile(model, inputs, ir="torch_compile") 返回一个 torch.compile 封装的函数,其后端配置为 TensorRT。在使用 ir=torch_compile 的情况下,用户可以使用 torch._dynamo.mark_dynamic API (https://pytorch.ac.cn/docs/stable/torch.compiler_dynamic_shapes.html) 为输入提供动态形状信息,以避免 TensorRT 引擎的重新编译。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224), dtype=float32)
# This indicates the dimension 0 is dynamic and the range is [1, 8]
torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
trt_gm = torch.compile(model, backend="tensorrt")
# Compilation happens when you call the model
trt_gm(inputs)

# No recompilation of TRT engines with modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
trt_gm(inputs_bs2)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源