创建一个 TorchScript 模块¶
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。PyTorch 有关于如何做到这一点的详细文档 https://pytorch.ac.cn/tutorials/beginner/Intro_to_TorchScript_tutorial.html,但这里简要介绍一下关键背景信息和过程。
PyTorch 程序基于 Module
,可用于构成更高级别的模块。Module
包含一个用于设置模块、参数和子模块的构造函数,以及一个描述在模块被调用时如何使用参数和子模块的前向函数。
例如,我们可以这样定义一个 LeNet 模块
1import torch.nn as nn
2import torch.nn.functional as F
3
4
5class LeNetFeatExtractor(nn.Module):
6 def __init__(self):
7 super(LeNetFeatExtractor, self).__init__()
8 self.conv1 = nn.Conv2d(1, 6, 3)
9 self.conv2 = nn.Conv2d(6, 16, 3)
10
11 def forward(self, x):
12 x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
13 x = F.max_pool2d(F.relu(self.conv2(x)), 2)
14 return x
15
16
17class LeNetClassifier(nn.Module):
18 def __init__(self):
19 super(LeNetClassifier, self).__init__()
20 self.fc1 = nn.Linear(16 * 6 * 6, 120)
21 self.fc2 = nn.Linear(120, 84)
22 self.fc3 = nn.Linear(84, 10)
23
24 def forward(self, x):
25 x = torch.flatten(x, 1)
26 x = F.relu(self.fc1(x))
27 x = F.relu(self.fc2(x))
28 x = self.fc3(x)
29 return x
30
31
32class LeNet(nn.Module):
33 def __init__(self):
34 super(LeNet, self).__init__()
35 self.feat = LeNetFeatExtractor()
36 self.classifier = LeNetClassifier()
37
38 def forward(self, x):
39 x = self.feat(x)
40 x = self.classifier(x)
41 return x
.
显然,你可能希望将这样一个简单的模型整合到一个模块中,但我们在这里可以看到 PyTorch 的可组合性。
从这里开始,有两条路径可以从 PyTorch Python 代码转换到 TorchScript 代码:跟踪(Tracing)和脚本化(Scripting)。
跟踪会遵循模块被调用时的执行路径并记录所发生的事情。要跟踪我们的 LeNet 模块的实例,我们可以使用一个示例输入来调用 torch.jit.trace
。
import torch
model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)
脚本化实际上是用一个编译器来检查你的代码,并生成一个等效的 TorchScript 程序。区别在于,由于跟踪是遵循模块的执行过程,它无法捕捉到例如控制流这样的东西。而通过处理 Python 代码,编译器可以包含这些组件。我们可以通过调用 torch.jit.script
在我们的 LeNet 模块上运行脚本编译器。
import torch
model = LeNet()
script_model = torch.jit.script(model)
选择哪条路径都有其原因,PyTorch 文档中有关于如何选择的信息。从 Torch-TensorRT 的角度来看,对跟踪模块的支持更好(即你的模块更有可能编译成功),因为它不包含完整编程语言的所有复杂性,尽管两条路径都支持。
在对你的模块进行脚本化或跟踪后,你会得到一个 TorchScript 模块。它包含了运行该模块所需的代码和参数,这些都存储在 Torch-TensorRT 可以使用的中间表示(IR)中。
以下是 LeNet 跟踪模块的 IR 看起来的样子
graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
%input.1 : Float(1, 1, 32, 32)):
%129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifier"](%self.1)
%119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
%137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
%138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
return (%138)
以及 LeNet 脚本化模块的 IR
graph(%self : __torch__.LeNet,
%x.1 : Tensor):
%2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
%x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
%5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifier"](%self)
%x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
return (%x.5)
你可以看到 IR 保留了我们 Python 代码中的模块结构。
在 Python 中使用 TorchScript¶
TorchScript 模块的运行方式与普通 PyTorch 模块相同。你可以使用 forward
方法来运行前向传播,或者直接调用模块 torch_script_module(in_tensor)
。JIT 编译器会即时编译和优化模块,然后返回结果。
将 TorchScript 模块保存到磁盘¶
对于跟踪或脚本化的模块,你都可以使用以下命令将模块保存到磁盘
import torch
model = LeNet()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")