
Mutable Torch TensorRT Module

We will demonstrate how easily the Mutable Torch TensorRT Module can be used to compile, interact with, and modify a TensorRT Graph Module.

Compiling a Torch-TensorRT module is straightforward, but modifying the compiled module can be challenging, especially when it comes to maintaining the state and connection between the PyTorch module and the corresponding Torch-TensorRT module. In ahead-of-time (AoT) scenarios, integrating Torch TensorRT with complex pipelines, such as the Hugging Face Stable Diffusion pipeline, becomes even more difficult. The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with Torch-TensorRT modules easier than ever.

In this tutorial, we will cover:
  1. An example workflow of the Mutable Torch TensorRT Module with ResNet 18

  2. Saving a Mutable Torch TensorRT Module

  3. Integration with Huggingface pipelines in a LoRA use case

  4. Usage of dynamic shapes with the Mutable Torch TensorRT Module

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from diffusers import DiffusionPipeline

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

Initialize the Mutable Torch TensorRT Module with settings.

settings = {
    "use_python_runtime": False,  # C++ runtime; required later for saving the module
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,  # keep weights refittable so the engine can be updated in place
}

model = models.resnet18(pretrained=True).to("cuda").eval()
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original PyTorch module.
# Compilation happens the first time the mutable module is called.
with torch.no_grad():
    mutable_module(*inputs)

Make modifications to the mutable module.

Making changes to the mutable module can trigger a refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger a refit, while adding a module to the model will trigger re-compilation (see the sketch after the refit example below).

model2 = models.resnet18(pretrained=False).to("cuda").eval()
mutable_module.load_state_dict(model2.state_dict())


# Check the output
# The refit happens when you call the mutable module again.
with torch.no_grad():
    expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(
        expected_output, refitted_output, 1e-2, 1e-2
    ), "Refit Result is not correct. Refit failed"

print("Refit successfully!")

Saving a Mutable Torch TensorRT Module

# Currently, saving is only enabled when "use_python_runtime" = False in settings
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
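
As a quick sanity check, the reloaded module can be called just like the module it was saved from. A minimal sketch, reusing the ResNet inputs from above and assuming the reloaded engine is immediately callable:

with torch.no_grad():
    saved_output = mutable_module(*inputs)
    reloaded_output = reload(*inputs)
# The reloaded engine should reproduce the saved module's output.
assert torch.allclose(saved_output, reloaded_output, 1e-2, 1e-2)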

Stable Diffusion with Huggingface

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "immutable_weights": False,
    }

    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    device = "cuda:0"

    prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
    negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
    BATCH = torch.export.Dim("BATCH", min=2, max=24)
    _HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
    _WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
    # Derived Dims constrain the latent height/width to multiples of 4.
    HEIGHT = 4 * _HEIGHT
    WIDTH = 4 * _WIDTH
    args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
    kwargs_dynamic_shapes = {
        "encoder_hidden_states": {0: BATCH},
        "added_cond_kwargs": {
            "text_embeds": {0: BATCH},
            "time_ids": {0: BATCH},
        },
        "return_dict": None,
    }
    pipe.unet.set_expected_dynamic_shape_range(
        args_dynamic_shapes, kwargs_dynamic_shapes
    )
    image = pipe(
        prompt,
        negative_prompt=negative,
        num_inference_steps=30,
        height=1024,
        width=768,
        num_images_per_prompt=2,
    ).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="all-disney-princess-xl-lo.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
    image = pipe(
        prompt,
        negative_prompt=negative,
        num_inference_steps=30,
        height=1024,
        width=1024,
        num_images_per_prompt=1,
    ).images[0]
    image.save("./with_LoRA_mutable.jpg")

Using Dynamic Shapes with the Mutable Torch TensorRT Module

When adding dynamic shape hints to the MutableTorchTensorRTModule, the shape hints should strictly follow the semantics of the arg_inputs and kwarg_inputs passed to the forward function, and no entry should be omitted (except for None in kwarg_inputs). If there is a nested dict/list in the inputs, the dynamic shape for that entry should also be a nested dict/list. If no dynamic shape is required for an input, an empty dictionary should be given as the shape hint for that input. Note that you should exclude keyword arguments whose value is None, as those will be filtered out.

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b, c={}):
        # Note: x is overwritten below; only the final assignment is returned.
        x = torch.matmul(a, b)
        x = torch.matmul(c["a"], c["b"].T)
        print(c["b"][0])
        x = 2 * c["b"]
        return x


device = "cuda:0"
model = Model().to(device).eval()
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)  # shares the name "dim" with dim_0, so the two dims are constrained to be equal
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
    "c": {
        "a": {},
        "b": {0: dim_2},
    },  # a's shape does not change so we give it an empty dict
}
# Wrap the model, then register custom dynamic shape constraints before the first call compiles it
model = torch_trt.MutableTorchTensorRTModule(model, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
with torch.no_grad():
    model(*inputs, **kwargs)
    # Change input shape
    inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
    kwargs_2 = {
        "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
    }
    # Run without recompiling
    model(*inputs_2, **kwargs_2)
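
The example above stays inside the declared ranges. Behavior outside those ranges is not covered by the original example; the continuation below is a hypothetical sketch assuming the module detects the mismatch and re-compiles on that call rather than raising an error.

# A second dimension of 60 exceeds the max of 50 declared for "dim", so the
# engine compiled above cannot serve these inputs. Assumption: the mismatch
# triggers a re-compilation on this call.
inputs_3 = (torch.rand(10, 60).to(device), torch.rand(60, 30).to(device))
kwargs_3 = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(8, 30).to(device)},
}
with torch.no_grad():
    model(*inputs_3, **kwargs_3)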

Using the Mutable Torch TensorRT Module with a Persistent Cache

By leveraging the engine cache, we can skip engine compilation and save significant time.

import os

from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH

model = models.resnet18(pretrained=True).to("cuda").eval()

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
model = torch_trt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=True,
    enabled_precisions={torch.float},
    min_block_size=1,
    immutable_weights=False,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_size=1 << 30,  # 1GB
)


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)


remove_timing_cache()

with torch.no_grad():
    for i in range(4):
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]

        start.record()
        model(*inputs)  # Recompile: each new batch size forces a rebuild, but later builds reuse the caches
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
print("With engine caching used:", times[1], "ms")
print("With engine caching used:", times[2], "ms")
print("With engine caching used:", times[3], "ms")
