• 文档 >
  • 保存用 Torch-TensorRT 编译的模型
快捷方式

保存使用 Torch-TensorRT 编译的模型

可以使用 torch_tensorrt.save API 保存用 Torch-TensorRT 编译的模型。

Dynamo IR

Torch-TensorRT 的 ir=dynamo 编译的输出类型默认为 torch.fx.GraphModule 对象。我们可以通过指定 output_format 标志,将此对象保存为 TorchScript (torch.jit.ScriptModule)、ExportedProgram (torch.export.ExportedProgram) 或 PT2 格式。以下是 output_format 可接受的选项:

  • exported_program:这是默认选项。我们首先对 graphmodule 执行转换,然后使用 torch.export.save 保存模块。

  • torchscript:我们通过 torch.jit.trace 追踪 graphmodule,并使用 torch.jit.save 保存它。

  • PT2 格式:这是 PyTorch 模型的下一代运行时,允许它们在 Python 和 C++ 中运行。

a) ExportedProgram

这是一个使用示例:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_ep is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Later, you can load it and run inference
model = torch.export.load("trt.ep").module()
model(*inputs)

b) Torchscript

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)

# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
model(*inputs)

Torchscript IR

在 Torch-TensorRT 1.X 版本中,使用 Torchscript IR 是编译和运行 Torch-TensorRT 推理的主要方式。对于 ir=ts,此行为在 2.X 版本中保持不变。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)

加载模型

我们可以直接使用 PyTorch 的 torch.jit.loadtorch.export.load API 加载 torchscript 或 exported_program 模型。另外,我们也提供了一个轻量级的包装器 torch_tensorrt.load(file_path),它可以加载上述任何一种模型类型。

这是一个使用示例:

import torch
import torch_tensorrt

# file_path can be trt.ep or trt.ts file obtained via saving the model (refer to the above section)
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
model = torch_tensorrt.load(<file_path>).module()
model(*inputs)

b) PT2 格式

PT2 是一种新格式,未来将允许模型在 Python 之外运行。它利用 AOTInductor 为那些不会在 TensorRT 中运行的组件生成内核。

以下是如何在 Python 中使用 AOTInductor 保存和加载 Torch-TensorRT 模块的示例:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_ep is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)

# Later, you can load it and run inference
model = torch._inductor.aoti_load_package("trt.pt2")
model(*inputs)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源