Torch-TensorRT (FX Frontend) 用户指南¶
Torch-TensorRT (FX Frontend) 是一个工具,可以通过 torch.fx 将 PyTorch 模型转换为 TensorRT 引擎,该引擎针对在 Nvidia GPU 上运行进行了优化。TensorRT 是 NVIDIA 开发的推理引擎,它包含了各种优化,包括内核融合、图优化、低精度等。该工具是用 Python 环境开发的,这使得研究人员和工程师非常容易使用此工作流程。用户使用此工具通常有几个阶段,我们将在下面进行介绍。
> Torch-TensorRT (FX Frontend) 处于 Beta 阶段,目前建议与 PyTorch nightly 版本一起使用。
# Test an example by
$ python py/torch_tensorrt/fx/example/lower_example.py
将 PyTorch 模型转换为 TensorRT 引擎¶
总的来说,欢迎用户使用 compile() 来完成从模型到 TensorRT 引擎的转换。它是一个封装 API,包含了完成此转换所需的主要步骤。请参考 examples/fx 目录下的 lower_example.py 文件中的示例用法。
def compile(
module: nn.Module,
input,
max_batch_size=2048,
max_workspace_size=33554432,
explicit_batch_dimension=False,
lower_precision=LowerPrecision.FP16,
verbose_log=False,
timing_cache_prefix="",
save_timing_cache=False,
cuda_graph_batch_size=-1,
dynamic_batch=True,
) -> nn.Module:
"""
Takes in original module, input and lowering setting, run lowering workflow to turn module
into lowered module, or so called TRTModule.
Args:
module: Original module for lowering.
input: Input for module.
max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
max_workspace_size: Maximum size of workspace given to TensorRT.
explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
lower_precision: lower_precision config given to TRTModule.
verbose_log: Enable verbose log for TensorRT if set True.
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
save_timing_cache: Update timing cache with current timing cache data if set to True.
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
dynamic_batch: batch dimension (dim=0) is dynamic.
Returns:
A torch.nn.Module lowered by TensorRT.
"""
在本节中,我们将通过一个示例来说明 FX 路径使用的主要步骤。用户可以参考 examples/fx 目录下的 fx2trt_example.py 文件。
步骤 1:使用 acc_tracer 跟踪模型
Acc_tracer 是一个继承自 FX tracer 的 tracer。它附带一个参数规范器,可以将所有参数转换为关键字参数并传递给 TRT 转换器。
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
# Build the model which needs to be a PyTorch nn.Module.
my_pytorch_model = build_model()
# Prepare inputs to the model. Inputs have to be a List of Tensors
inputs = [Tensor, Tensor, ...]
# Trace the model with acc_tracer.
acc_mod = acc_tracer.trace(my_pytorch_model, inputs)
常见错误
符号化跟踪的变量不能用作控制流的输入。这意味着模型包含动态控制流。请参考 FX 指南中“动态控制流”部分。
步骤 2:构建 TensorRT 引擎
TensorRT 有 两种不同的模式 来处理批次维度:显式批次维度和隐式批次维度。这种模式被早期版本的 TensorRT 使用,现在已弃用,但为向后兼容性而继续支持。在显式批次模式下,所有维度都是显式的,并且可以是动态的,即它们的长度可以在执行时更改。许多新功能,如动态形状和循环,仅在此模式下可用。当用户在 compile() 中将 explicit_batch_dimension 设置为 False 时,用户仍然可以选择使用隐式批次模式。我们不推荐使用它,因为它在未来 TensorRT 版本中将不再支持。
显式批次是默认模式,并且对于动态形状必须设置它。对于大多数视觉任务,如果用户希望获得与仅批次维度发生变化的隐式模式类似的效果,可以在 compile() 中启用 dynamic_batch。它有一些要求:1. 输入、输出和激活的形状是固定的,除了批次维度。2. 输入、输出和激活具有批次维度作为主要维度。3. 模型中的所有运算符都不会修改批次维度(例如置换、转置、分割等)或在批次维度上进行计算(例如求和、softmax 等)。
例如,如果我们有一个形状为 (batch, sequence, dimension) 的 3D 张量 t,则操作如 torch.transpose(0, 2)。如果这三个条件中任何一个不满足,我们就需要指定 InputTensorSpec 作为具有动态范围的输入。
import deeplearning.trt.fx2trt.converter.converters
from torch.fx.experimental.fx2trt.fx2trt import InputTensorSpec, TRTInterpreter
# InputTensorSpec is a dataclass we use to store input information.
# There're two ways we can build input_specs.
# Option 1, build it manually.
input_specs = [
InputTensorSpec(shape=(1, 2, 3), dtype=torch.float32),
InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
]
# Option 2, build it using sample_inputs where user provide a sample
inputs = [
torch.rand((1,2,3), dtype=torch.float32),
torch.rand((1,4,5), dtype=torch.float32),
]
input_specs = InputTensorSpec.from_tensors(inputs)
# IMPORTANT: If dynamic shape is needed, we need to build it slightly differently.
input_specs = [
InputTensorSpec(
shape=(-1, 2, 3),
dtype=torch.float32,
# Currently we only support one set of dynamic range. User may set other dimensions but it is not promised to work for any models
# (min_shape, optimize_target_shape, max_shape)
# For more information refer to fx/input_tensor_spec.py
shape_ranges = [
((1, 2, 3), (4, 2, 3), (100, 2, 3)),
],
),
InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
]
# Build a TRT interpreter. Set explicit_batch_dimension accordingly.
interpreter = TRTInterpreter(
acc_mod, input_specs, explicit_batch_dimension=True/False
)
# The output of TRTInterpreter run() is wrapped as TRTInterpreterResult.
# The TRTInterpreterResult contains required parameter to build TRTModule,
# and other informational output from TRTInterpreter run.
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
output_names: Sequence[str]
serialized_cache: bytearray
#max_batch_size: set accordingly for maximum batch size you will use.
#max_workspace_size: set to the maximum size we can afford for temporary buffer
#lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
#sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
#force_fp32_output: force output to be fp32
#strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric #reasons.
#algorithm_selector: set up algorithm selection for certain layer
#timing_cache: enable timing cache for TensorRT
#profiling_verbosity: TensorRT logging level
trt_interpreter_result = interpreter.run(
max_batch_size=64,
max_workspace_size=1 << 25,
sparse_weights=False,
force_fp32_output=False,
strict_type_constraints=False,
algorithm_selector=None,
timing_cache=None,
profiling_verbosity=None,
)
常见错误
RuntimeError: Conversion of function xxx not currently supported! - 这意味着我们不支持 xxx 运算符。请参阅下面的“如何添加缺失的 op”部分以获取进一步说明。
步骤 3:运行模型
一种方法是使用 TRTModule,它基本上是一个 PyTorch nn.Module。
from torch_tensorrt.fx import TRTModule
mod = TRTModule(
trt_interpreter_result.engine,
trt_interpreter_result.input_names,
trt_interpreter_result.output_names)
# Just like all other PyTorch modules
outputs = mod(*inputs)
torch.save(mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
reload_model_output = reload_trt_mod(*inputs)
到目前为止,我们详细解释了将 PyTorch 模型转换为 TensorRT 引擎的主要步骤。欢迎用户参考源代码以获取一些参数解释。在转换方案中,有两个重要的操作。一个是 acc tracer,它帮助我们将 PyTorch 模型转换为 acc 图。另一个是 FX 路径转换器,它帮助将 acc 图的操作转换为相应的 TensorRT 操作,并为其构建 TensoRT 引擎。
Acc Tracer¶
Acc tracer 是一个自定义的 FX 符号跟踪器。与原生的 FX 符号跟踪器相比,它做了更多的事情。我们主要依赖它将 PyTorch op 或内置 op 转换为 acc op。FX2TRT 使用 acc op 主要有两个目的:
在 PyTorch op 和内置 op 中有很多执行类似操作的 op,例如 torch.add、builtin.add 和 torch.Tensor.add。使用 acc tracer,我们将这三个 op 规范化为单个 acc_ops.add。这有助于减少我们需要编写的转换器数量。
acc op 只有关键字参数,这使得编写转换器更容易,因为我们不需要添加额外的逻辑来查找参数和关键字参数。
FX2TRT¶
符号化跟踪后,我们得到了 PyTorch 模型的图表示。fx2trt 利用了 fx.Interpreter 的强大功能。fx.Interpreter 逐个节点遍历整个图,并调用节点表示的函数。fx2trt 通过调用相应转换器来覆盖调用函数的原始行为。每个转换器函数都会添加相应的 TensorRT 层。
下面是一个转换器函数的示例。该装饰器用于将此转换器函数与相应的节点注册。在此示例中,我们将此转换器注册到一个目标为 acc_ops.sigmoid 的 fx 节点。
@tensorrt_converter(acc_ops.sigmoid)
def acc_ops_sigmoid(network, target, args, kwargs, name):
"""
network: TensorRT network. We'll be adding layers to it.
The rest arguments are attributes of fx node.
"""
input_val = kwargs['input']
if not isinstance(input_val, trt.tensorrt.ITensor):
raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
'of the TensorRT region!')
layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
layer.name = name
return layer.get_output(0)
如何添加缺失的 Op¶
你实际上可以将其添加到任何你想要的地方,只需记住导入该文件,以便在用 acc_tracer 进行跟踪之前,所有 acc op 和映射器都会被注册。
步骤 1. 添加一个新的 acc op
待办事项:需要更多地解释 acc op 的逻辑,例如我们何时想要分解一个 op,何时想要重用其他 op。
在 acc tracer 中,如果一个节点存在映射到 acc op,我们会将其转换为 acc op。
为了实现到 acc op 的转换,需要两件事。一是应该定义一个 acc op 函数,二是应该注册一个映射。
定义 acc op 非常简单,我们首先需要一个函数,并通过此装饰器 acc_normalizer.py 将该函数注册为 acc op。例如,以下代码添加了一个名为 foo() 的 acc op,它将两个给定输入相加。
# NOTE: all acc ops should only take kwargs as inputs, therefore we need the "*"
# at the beginning.
@register_acc_op
def foo(*, input, other, alpha):
return input + alpha * other
注册映射有两种方法。一种是 register_acc_op_mapping()。让我们将一个映射从 torch.add 注册到我们上面创建的 foo()。我们需要将 register_acc_op_mapping 装饰器添加到它。
this_arg_is_optional = True
@register_acc_op_mapping(
op_and_target=("call_function", torch.add),
arg_replacement_tuples=[
("input", "input"),
("other", "other"),
("alpha", "alpha", this_arg_is_optional),
],
)
@register_acc_op
def foo(*, input, other, alpha=1.0):
return input + alpha * other
op_and_target 确定哪个节点将触发此映射。op 和 target 是 FX 节点的属性。在 acc_normalization 中,当我们遇到一个具有与 op_and_target 中设置的相同 op 和 target 的节点时,我们将触发映射。由于我们希望从 torch.add 进行映射,因此 op 将是 call_function,target 将是 torch.add。arg_replacement_tuples 决定了我们如何使用原始节点的参数和关键字参数来构建新 acc op 节点的关键字参数。 arg_replacement_tuples 中的每个元组代表一个参数映射规则。它包含两个或三个元素。第三个元素是一个布尔变量,用于确定此关键字参数在*原始节点*中是否是可选的。仅当第三个元素为 True 时,我们才需要指定它。第一个元素是原始节点中的参数名称,它将被用作 acc op 节点中名称为元组第二个元素的参数。元组的顺序很重要,因为元组的位置决定了参数在原始节点参数中的位置。我们使用此信息将原始节点中的参数映射到 acc op 节点中的关键字参数。如果以下任一情况不成立,我们则无需指定 arg_replacement_tuples:
原始节点的关键字参数和 acc op 节点的关键字参数名称不同。
存在可选参数。
注册映射的另一种方法是通过 register_custom_acc_mapper_fn()。这个函数旨在减少冗余的 op 注册,因为它允许你使用一个函数通过某些组合映射到一个或多个现有的 acc op。在函数中,你可以做任何你想做的事情。让我们用一个例子来解释它是如何工作的。
@register_acc_op
def foo(*, input, other, alpha=1.0):
return input + alpha * other
@register_custom_acc_mapper_fn(
op_and_target=("call_function", torch.add),
arg_replacement_tuples=[
("input", "input"),
("other", "other"),
("alpha", "alpha", this_arg_is_optional),
],
)
def custom_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
"""
`node` is original node, which is a call_function node with target
being torch.add.
"""
alpha = 1
if "alpha" in node.kwargs:
alpha = node.kwargs["alpha"]
foo_kwargs = {"input": node["input"], "other": node["other"], "alpha": alpha}
with node.graph.inserting_before(node):
foo_node = node.graph.call_function(foo, kwargs=foo_kwargs)
foo_node.meta = node.meta.copy()
return foo_node
在自定义映射函数中,我们构建一个 acc op 节点并返回它。这里返回的节点将接管原始节点的所有子节点 acc_normalizer.py。
最后一步是为我们添加的新 acc op 或映射函数*添加单元测试*。添加单元测试的地方在这里 test_acc_tracer.py。
步骤 2. 添加一个新的转换器
所有已开发的 acc op 转换器都在 acc_op_converter.py 中。它可以很好地展示转换器是如何添加的。
本质上,转换器是将 acc op 映射到 TensorRT 层的机制。如果我们能找到所有需要的 TensorRT 层,我们就可以使用 TensorRT API 开始为节点添加转换器。
@tensorrt_converter(acc_ops.sigmoid)
def acc_ops_sigmoid(network, target, args, kwargs, name):
"""
network: TensorRT network. We'll be adding layers to it.
The rest arguments are attributes of fx node.
"""
input_val = kwargs['input']
if not isinstance(input_val, trt.tensorrt.ITensor):
raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
'of the TensorRT region!')
layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
layer.name = name
return layer.get_output(0)
我们需要使用 tensorrt_converter 装饰器来注册转换器。装饰器的参数是我们需要的 fx 节点的 target。在转换器中,我们可以在 kwargs 中找到 fx 节点的输入。例如,原始节点是 acc_ops.sigmoid,它在 acc_ops.py 中只有一个参数“input”。我们获取输入并检查它是否是 TensorRT 张量。之后,我们向 TensorRT 网络添加一个 sigmoid 层,并返回该层的输出。我们返回的输出将通过 fx.Interpreter 传递给 acc_ops.sigmoid 的子节点。
如果找不到与节点功能相同的 TensorRT 层怎么办?
在这种情况下,我们需要做更多的工作。TensorRT 提供插件,作为自定义层。*我们尚未实现此功能。一旦启用,我们将进行更新*。
最后一步是为我们添加的新转换器添加单元测试。用户可以在此 文件夹 中添加相应的单元测试。