直接通过 PyTorch 使用 Torch-TensorRT TorchScript 前端¶
您现在将能够直接从 PyTorch API 访问 TensorRT。使用此功能的过程与 在 Python 中使用 Torch-TensorRT 中描述的编译工作流程非常相似。
首先将 torch_tensorrt
加载到您的应用程序中。
import torch
import torch_tensorrt
然后,对于一个给定的 TorchScript 模块,您可以使用 torch._C._jit_to_backend("tensorrt", ...)
API 通过 TensorRT 对其进行编译。
import torchvision.models as models
model = models.mobilenet_v2(pretrained=True)
script_model = torch.jit.script(model)
Torch-TensorRT 中的 compile
API 假定您要编译模块的 forward
函数,而 convert_method_to_trt_engine
则将指定函数转换为 TensorRT 引擎。与它们不同,后端 API 将接受一个字典,该字典将要编译的函数名称映射到编译规范(Compilation Spec)对象,这些对象包装了您提供给 compile
的同类字典。有关编译规范字典的更多信息,请查看 Torch-TensorRT TensorRTCompileSpec
API 的文档。
spec = {
"forward": torch_tensorrt.ts.TensorRTCompileSpec(
**{
"inputs": [torch_tensorrt.Input([1, 3, 300, 300])],
"enabled_precisions": {torch.float, torch.half},
"refit": False,
"debug": False,
"device": {
"device_type": torch_tensorrt.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True,
},
"capability": torch_tensorrt.EngineCapability.default,
"num_avg_timing_iters": 1,
}
)
}
现在,要使用 Torch-TensorRT 进行编译,请将目标模块对象和规范字典提供给 torch._C._jit_to_backend("tensorrt", ...)
trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
要运行,请显式调用您想要运行的方法的函数(这与在标准 PyTorch 中可以直接在模块本身上调用的方式不同)。
input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))