• 文档 >
  • 使用 Torch-TensorRT 编译导出的程序
快捷方式

使用 Torch-TensorRT 编译导出的程序

Pytorch 2.1 引入了 torch.export API,可将 Pytorch 程序的图导出为 ExportedProgram 对象。Torch-TensorRT dynamo 前端编译这些 ExportedProgram 对象并使用 TensorRT 对其进行优化。以下是 dynamo 前端的一个简单用法

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
trt_gm(*inputs)

注意

torch_tensorrt.dynamo.compile 是用户与 Torch-TensorRT dynamo 前端交互的主要 API。模型的输入类型应为 ExportedProgram(理想情况下是 torch.export.exporttorch_tensorrt.dynamo.trace 的输出(将在下面章节讨论)),输出类型是 torch.fx.GraphModule 对象。

可自定义设置

用户可以通过大量选项来自定义其使用 TensorRT 进行优化的设置。一些常用选项如下

  • inputs - 对于静态形状,这可以是一个 torch 张量列表或 torch_tensorrt.Input 对象列表。对于动态形状,这应该是一个 torch_tensorrt.Input 对象列表。

  • enabled_precisions - TensorRT 构建器在优化期间可以使用的精度集合。

  • truncate_long_and_double - 分别将 long 和 double 值截断为 int 和 float。

  • torch_executed_ops - 强制由 Torch 执行的算子。

  • min_block_size - 作为 TensorRT 段执行所需的最小连续算子数。

完整的选项列表可以在此处找到

注意

我们目前在 Dynamo 中不支持 INT 精度。我们的 Torchscript IR 中目前存在对此的支持。我们计划在下一个版本中为 Dynamo 实现类似的支持。

底层原理

在底层,torch_tensorrt.dynamo.compile 对图执行以下操作。

  • 降级(Lowering)- 应用降级过程来添加/删除算子,以实现最佳转换。

  • 分区(Partitioning)- 根据 min_block_sizetorch_executed_ops 字段将图划分为 Pytorch 和 TensorRT 段。

  • 转换(Conversion)- 在此阶段,Pytorch 算子被转换为 TensorRT 算子。

  • 优化(Optimization)- 转换后,我们构建 TensorRT 引擎并将其嵌入到 PyTorch 图中。

追踪(Tracing)

torch_tensorrt.dynamo.trace 可用于追踪 Pytorch 图并生成 ExportedProgram。这在内部执行了一些算子的分解,以便进行下游优化。ExportedProgram 随后可与 torch_tensorrt.dynamo.compile API 一起使用。如果您的模型具有动态输入形状,您可以使用此 torch_tensorrt.dynamo.trace 来导出具有动态形状的模型。或者,您也可以直接使用 torch.export 带约束的方式。

import torch
import torch_tensorrt

inputs = [torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                              opt_shape=(4, 3, 224, 224),
                              max_shape=(8, 3, 224, 224),
                              dtype=torch.float32)]
model = MyModel().eval()
exp_program = torch_tensorrt.dynamo.trace(model, inputs)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源