降级阶段¶
降级阶段由多个通道组成,这些通道是将图形从高级表示映射到低级表示的操作。每个通道都做一些特定的事情,例如内联方法调用。其目的是显著减少转换阶段在实际映射到 TensorRT 时需要处理的内容。我们的目标是接近 1 对 1 的操作转换,而不是寻找适用的子图,从而限制转换器的数量并缩小每个转换器的范围。
通过将日志级别设置为 Level::kGraph
,可以看到每个通道的效果
使用的通道¶
消除公共子表达式¶
删除图中的公共子表达式
消除死代码¶
死代码消除将检查节点是否有副作用,如果有则不删除它。
消除异常或通过模式¶
在脚本模块中常见的模式是维度保护,如果输入维度不符合预期,它将抛出异常。
%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
= prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
block0():
= prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
-> ()
block1():
-> ()
由于我们在编译时解决了所有这些问题,并且 TensorRT 图中没有异常,因此我们只需将其删除。
消除冗余保护¶
消除操作的冗余保护,这些操作的输出完全由其输入确定,即如果这些操作的输入受到保护,我们就可以删除操作输出上的保护
冻结模块¶
冻结属性并内联常量和模块。在图中传播常量。
融合 AddMM 分支¶
脚本模块中常见的模式是不同维度的张量使用不同的构造来实现线性层。我们将这些不同的变体融合为一个,这将由 Unpack AddMM 通道捕获。
%ret : Tensor = prim::If(%622)
block0():
%ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
-> (%ret.1)
block1():
%output.1 : Tensor = aten::matmul(%x9.1, %3677)
%output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
-> (%output0.1)
我们将这组块融合到这样的图中
%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
融合线性¶
匹配 aten::linear
模式并将其融合为一个 aten::linear
此通道将 JIT 生成的 addmm 或 matmul + add 融合回线性
融合扁平线性¶
当输入层高于 1D 时,TensorRT 会隐式地将其扁平化为全连接层。因此,当存在 aten::flatten
-> aten::linear
模式时,我们删除 aten::flatten
。
降低图¶
给定一个方法图,其第一个参数为 %self,将其降低到所有属性访问都替换为图的显式输入(而不是在 %self 上执行的 prim::GetAttr 的结果)的图。返回一个元组 (graph, parameters),其中图的最后 module.parameters.size() 输入是此方法中使用的可训练参数。其余输入是函数的真实输入。
降低元组¶
LowerSimpleTuples
:
删除 TupleConstruct 和 TupleUnpack 匹配的元组,但保留 if 语句、循环以及作为输入/输出的元组
LowerAllTuples
:
删除 _所有_ 元组,如果某些元组无法删除则引发错误,这由 ONNX 用于确保在转换前没有元组,但对输入包含元组的图不起作用。
模块回退¶
模块回退由两个必须成对运行的降级通道组成。第一个通道在冻结之前运行,以在图中放置定界符,围绕应在 PyTorch 中运行的模块。第二个通道在冻结之后标记这些定界符之间的节点,以表示它们应在 PyTorch 中运行。
NotateModuleForFallback
在冻结前在模块调用周围放置定界节点,以指示图中应在 PyTorch 中运行的节点位置
MarkNodesForFallback
查找定界符,然后标记定界符之间的所有节点,告诉分区在 PyTorch 中运行它们
窥孔优化¶
此优化通道的目的是捕获您可能感兴趣的所有小型、易于捕获的窥孔优化。
- 目前,它执行
消除无操作的“expand”节点
简单地将 x.t().t() 转换为 x
移除 Contiguous¶
移除 contiguous 运算符,因为我们正在使用 TensorRT 内存,它已经是连续的。
移除 Dropout¶
移除 dropout 运算符,因为我们正在进行推理。
移除 To¶
移除执行类型转换的 aten::to
运算符,因为 TensorRT 会自行管理。重要的是,这是最后运行的通道之一,以便其他通道有机会将所需的类型转换运算符移出主命名空间。
解包 AddMM¶
将 aten::addmm
解包为 aten::matmul
和 aten::add_
(带有一个额外的 trt::const
操作以冻结 TensorRT 图中的偏差)。这使我们能够重用 aten::matmul
和 aten::add_
转换器,而无需专用转换器。
解包 LogSoftmax¶
将 aten::logsoftmax
解包为 aten::softmax
和 aten::log
。这使我们能够重用 aten::softmax
和 aten::log
转换器,而无需专用转换器。
展开循环¶
展开兼容循环(例如足够短的循环)的操作,这样您只需通过循环一次。
用 Repeat 替换 Tile¶
移除 dropout 运算符,因为我们正在进行推理。