• 文档 >
  • 使用 Torch-TensorRT dynamo 后端编译 FLUX.1-dev 模型
快捷方式

使用 Torch-TensorRT dynamo 后端编译 FLUX.1-dev 模型

本示例说明了使用 Torch-TensorRT 优化的最先进模型 FLUX.1-dev

FLUX.1 [dev] 是一个 120 亿参数的整流流(rectified flow)Transformer,能够根据文本描述生成图像。它是一个开放权重、经过指导蒸馏的模型,仅供非商业用途。

要运行此演示,您需要访问 Flux 模型(如果您还没有,请在 FLUX.1-dev 页面上请求访问)并安装以下依赖项

pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"

FLUX.1-dev 管道有不同的组成部分,例如 transformervaetext_encodertokenizerscheduler。在此示例中,我们演示了优化模型中的 transformer 组件(该组件通常消耗端到端扩散延迟的 95% 以上)。

import register_sdpa  # Register SDPA as a standalone operator

导入以下库

import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export

定义 FLUX-1.dev 模型

使用 FluxPipeline 类加载 FLUX-1.dev 预训练管道。 FluxPipeline 包含生成图像所需的 transformervaetext_encodertokenizerscheduler 等不同组件。我们使用 torch_dtype 参数以 FP16 精度加载权重。

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)

# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer.to(DEVICE)

使用 torch.export 导出骨干网络

定义虚拟输入及其各自的动态形状。我们导出具有动态形状的 transformer 骨干网络,batch_size=2 是因为 0/1 特殊化

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    prefer_deferred_runtime_asserts_over_guards=True,
)

Torch-TensorRT 编译

注意

编译需要高内存(> 80GB)的 GPU,因为 TensorRT 以 FP32 精度存储权重。这是一个已知问题,将在未来得到解决。

我们使用 use_fp32_acc=True 启用 FP32 矩阵乘法累加,以通过引入到 FP32 节点中的转换来确保精度得到保留。我们还启用了显式类型设置,以确保 TensorRT 尊重用户设置的数据类型,这是 FP32 矩阵乘法累加的要求。由于这是一个 120 亿参数的模型,在 H100 GPU 上大约需要 20-30 分钟才能编译。该模型完全可转换,并生成单个 TensorRT 引擎。

trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
    immutable_weights=False,
    offload_module_to_cpu=True,
)

后处理

释放导出的程序和 pipe.transformer 占用的 GPU 内存。将 Flux 管道中的 transformer 设置为 Torch-TRT 编译后的模型。

pipe.transformer = None
pipe.to(DEVICE)
pipe.transformer = trt_gm
del ep
torch.cuda.empty_cache()
pipe.transformer.config = config
trt_gm.device = torch.device("cuda")

使用提示生成图像

提供提示和要生成的图像的文件名。这里我们使用提示 A golden retriever holding a sign to code

# Function which generates images from the flux pipeline
def generate_image(pipe, prompt, image_name):
    seed = 42
    with torch.no_grad():
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(seed),
        ).images[0]
        image.save(f"{image_name}.png")
        print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

生成的图像如下所示

tutorials/_rendered_examples/dynamo/dog_code.png

脚本总运行时间: ( 0 分 0.000 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源