降低阶段¶
降低阶段由一系列 Pass 组成,这些 Pass 是将图从高层表示映射到低层表示的操作。每个 Pass 执行特定的操作,例如内联方法调用。其目的是显著减少转换阶段在实际映射到 TensorRT 时需要处理的内容。我们旨在实现更接近 1:1 的算子转换,而不是寻找适用的子图,从而限制转换器的数量并缩小每个转换器的范围。
您可以通过将日志级别设置为 Level::kGraph 来查看每个 Pass 的效果。
使用的 Pass¶
EliminateCommonSubexpression¶
移除图中的公共子表达式。
Eliminate Dead Code¶
消除死代码将检查节点是否具有副作用,如果有副作用则不删除它。
Eliminate Exception Or Pass Pattern¶
脚本化模块中的常见模式是维度检查,如果输入维度不符合预期,则会抛出异常。
%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
= prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
block0():
= prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
-> ()
block1():
-> ()
由于我们在编译时解决了所有这些问题,并且 TensorRT 图中没有异常,因此我们将其删除。
Eliminate Redundant Guards¶
消除算子输出完全由其输入确定的算子的冗余检查。也就是说,如果算子的输入已检查,则允许我们删除算子输出上的检查。
Freeze Module¶
冻结属性并内联常量和模块。在图中传播常量。
Fuse AddMM Branches¶
脚本化模块中的常见模式是不同维度的张量使用不同的构造来实现线性层。我们将这些不同的变体融合到一个单一的变体中,该变体将被 Unpack AddMM Pass 捕获。
%ret : Tensor = prim::If(%622)
block0():
%ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
-> (%ret.1)
block1():
%output.1 : Tensor = aten::matmul(%x9.1, %3677)
%output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
-> (%output0.1)
我们将这组块融合为如下所示的图:
%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
Fuse Linear¶
匹配 aten::linear 模式并将其融合为单个 aten::linear。此 Pass 将 JIT 生成的 addmm 或 matmul + add 融合回 linear。
Fuse Flatten Linear¶
当输入层维度高于 1D 时,TensorRT 会隐式地将输入层展平为全连接层。因此,当存在 aten::flatten -> aten::linear 模式时,我们会删除 aten::flatten。
Lower Graph¶
给定一个方法图,其中第一个参数为 %self,将其降低为一个图,其中所有属性访问都被替换为图的显式输入(而不是在 %self 上执行的 prim::GetAttr 的结果)。返回一个元组 (graph, parameters),其中图的最后 module.parameters.size() 个输入是此方法中使用的可训练参数。其余输入是函数的实际输入。
Lower Tuples¶
LowerSimpleTuples:
移除 TupleConstruct 和 TupleUnpack 被匹配的元组,但将 if 语句、循环以及作为输入/输出的元组保留在原位。
LowerAllTuples:
移除 _所有_ 元组,并在无法移除时引发错误。ONNX 会使用此方法来确保转换前没有元组,但它不适用于输入包含元组的图。
Module Fallback¶
模块回退由两个必须成对运行的降低 Pass 组成。第一个 Pass 在冻结之前运行,用于在图中放置分隔符,围绕应在 PyTorch 中运行的模块。第二个 Pass 在冻结后标记这些分隔符之间的节点,以指示它们应在 PyTorch 中运行。
NotateModuleForFallback
在冻结前将分隔节点放置在模块调用周围,以指示图中哪些节点应在 PyTorch 中运行。
MarkNodesForFallback
查找分隔符,然后标记分隔符之间的所有节点,以告知分区器在 PyTorch 中运行它们。
Peephole Optimize¶
此优化 Pass 的目的是捕获您可能希望进行的所有小型、易于捕获的窥孔优化。
- 目前,它执行以下操作:
消除无操作的“expand”节点。
将 x.t().t() 简化为 x。
Remove Contiguous¶
移除 contiguous 算子,因为我们使用的是 TensorRT,内存已经是连续的。
Remove Dropout¶
移除 dropout 算子,因为我们正在进行推理。
Remove To¶
移除执行类型转换的 aten::to 算子,因为 TensorRT 会自行管理。重要的是,这是最后一个运行的 Pass 之一,以便其他 Pass 有机会将必需的类型转换算子移出主命名空间。
Unpack AddMM¶
将 aten::addmm 解包为 aten::matmul 和 aten::add_(并添加一个额外的 trt::const 算子将偏置冻结在 TensorRT 图中)。这使我们可以重用 aten::matmul 和 aten::add_ 转换器,而无需专门的转换器。
Unpack LogSoftmax¶
将 aten::logsoftmax 解包为 aten::softmax 和 aten::log。这使我们可以重用 aten::softmax 和 aten::log 转换器,而无需专门的转换器。
Unroll Loops¶
展开兼容循环(例如,足够短的循环)的操作,以便您只需遍历循环一次。
Replace Tile with Repeat¶
移除 dropout 算子,因为我们正在进行推理。