评价此页

torch.export 流程演示、常见挑战及解决方案#

作者: Ankith Gunapal, Jordi Ramon, Marcos Carranza

torch.export 教程简介 中,我们学习了如何使用 torch.export。本教程在此基础上进行了扩展,探讨了导出常用模型的代码流程,并解决了在使用 torch.export 时可能出现的常见挑战。

在本教程中,你将学习如何为以下用例导出模型:

选择这四个模型是为了演示 torch.export 的独特功能,以及在实现过程中面临的一些实际考量和问题。

先决条件#

  • PyTorch 2.4 或更高版本

  • torch.export 和 PyTorch Eager 推理有基础了解。

torch.export 的关键要求:无图中断 (No graph break)#

torch.compile 通过使用 JIT 将 PyTorch 代码编译为优化内核来加速 PyTorch 代码。它使用 TorchDynamo 优化给定模型并创建一个优化图,然后使用 API 中指定的后端将其降低到硬件。当 TorchDynamo 遇到不支持的 Python 特性时,它会中断计算图,让默认的 Python 解释器处理不支持的代码,然后再恢复捕获图。这种计算图的中断称为 图中断 (graph break)

torch.exporttorch.compile 之间的关键区别之一是 torch.export 不支持图中断,这意味着你正在导出的整个模型或模型的一部分必须是一个单一的图。这是因为处理图中断涉及使用默认的 Python 求值来解释不支持的操作,这与 torch.export 的设计初衷不兼容。你可以通过此 链接 阅读有关各种 PyTorch 框架之间差异的详细信息。

你可以使用以下命令识别程序中的图中断:

TORCH_LOGS="graph_breaks" python <file_name>.py

你需要修改程序以消除图中断。一旦解决,就可以导出模型了。PyTorch 为 torch.compile 在流行的 HuggingFace 和 TIMM 模型上运行 每日基准测试。这些模型中的大多数都没有图中断。

本配方中的模型没有图中断,但使用 torch.export 时会失败。

视频分类#

MViT 是一类基于 多尺度视觉 Transformer 的模型。该模型已使用 Kinetics-400 数据集 进行视频分类训练。该模型配合相关数据集可用于游戏环境中的动作识别。

下面的代码通过 batch_size=2 进行跟踪来导出 MViT,然后检查 ExportedProgram 是否可以以 batch_size=4 运行。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

错误:静态批次大小 (Static batch size)#

    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

默认情况下,导出流程会假设所有输入形状都是静态的来跟踪程序,因此如果你使用与跟踪时不同的输入形状运行程序,将会遇到错误。

解决方案#

为了解决该错误,我们将输入的第一个维度(batch_size)指定为动态,并指定 batch_size 的预期范围。在下面展示的修正示例中,我们指定预期的 batch_size 范围可以是 1 到 16。需要注意的一个细节是 min=2 并非 bug,这在 0/1 特化问题 中有解释。有关 torch.export 动态形状的详细描述可以在导出教程中找到。下面展示的代码演示了如何使用动态批次大小导出 mViT。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb


model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)

# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

自动语音识别#

自动语音识别 (ASR) 是利用机器学习将口述语言转录为文本的技术。Whisper 是 OpenAI 的一种基于 Transformer 的编码器-解码器模型,它在 68 万小时的标记数据上进行了 ASR 和语音翻译训练。下面的代码尝试导出 whisper-tiny 模型用于 ASR。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

错误:使用 TorchDynamo 进行严格跟踪 (strict tracing)#

torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'

默认情况下,torch.export 使用 TorchDynamo(一个字节码分析引擎)来跟踪你的代码,它会对你的代码进行符号分析并构建图。这种分析提供了更强的安全保证,但并非支持所有 Python 代码。当我们使用默认的严格模式导出 whisper-tiny 模型时,由于存在不支持的特性,Dynamo 通常会报错。要了解为什么 Dynamo 会报错,可以参考此 GitHub issue

解决方案#

为了解决上述错误,torch.export 支持 non_strict 模式。在该模式下,程序通过 Python 解释器进行跟踪,其工作方式类似于 PyTorch 的 Eager 执行。唯一的区别是所有 Tensor 对象都将被替换为 ProxyTensors,它会将所有操作记录到一个图中。通过使用 strict=False,我们可以成功导出程序。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

图像字幕生成#

图像字幕生成 是用文字描述图像内容的一项任务。在游戏背景下,图像字幕生成可用于通过动态生成场景中各种游戏对象的文字描述来增强游戏体验,从而为玩家提供额外细节。BLIP 是由 SalesForce Research 发布 的一种流行的图像字幕生成模型。下面的代码尝试以 batch_size=1 导出 BLIP。

import torch
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)

错误:无法改变具有冻结存储的张量 (Cannot mutate tensors with frozen storage)#

在导出模型时,可能会因为模型实现中包含某些 torch.export 尚不支持的 Python 操作而失败。其中一些失败可能有变通方法。BLIP 是一个原始模型报错的例子,通过在代码中进行小改动即可解决。torch.exportExportDB 中列出了支持和不支持操作的常见情况,并展示了如何修改代码以使其兼容导出。

File "/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

解决方案#

克隆导出失败的 张量

text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id

注意

在 PyTorch 2.7 夜间构建版本中,此限制已放宽。这在 PyTorch 2.7 中应该可以直接使用。

可提示的图像分割#

图像分割 是一种计算机视觉技术,它根据图像特征将数字图像划分为不同的像素组或片段。分割一切模型 (SAM) 引入了可提示的图像分割,它根据指示所需对象的提示来预测对象掩码。SAM 2 是第一个用于跨图像和视频分割对象的统一模型。SAM2ImagePredictor 类为提示模型提供了简单的接口。该模型可以接收点和框提示作为输入,以及来自前一次预测迭代的掩码。由于 SAM 2 为对象跟踪提供了强大的零样本性能,因此它可用于跟踪场景中的游戏对象。

SAM2ImagePredictor 的 predict 方法中的张量操作是在 _predict 方法中进行的。因此,我们尝试按如下方式导出。

ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)

错误:模型不是 torch.nn.Module 类型#

torch.export 要求模块的类型必须是 torch.nn.Module。然而,我们试图导出的是一个类方法,因此报错。

Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

解决方案#

我们编写一个继承自 torch.nn.Module 的辅助类,并在该类的 forward 方法中调用 _predict 方法。完整代码可以在 这里 找到。

class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

 model_to_export = ExportHelper()
 ep = torch.export.export(
      model_to_export,
      args=(unnorm_coords, labels, unnorm_box, mask_input,  multimask_output),
      kwargs={"return_logits": return_logits},
      strict=False,
      )

结论#

在本教程中,我们学习了如何通过正确的配置和简单的代码修改来解决挑战,从而使用 torch.export 为常用用例导出模型。一旦能够导出模型,对于服务器,你可以使用 AOTInductor,对于边缘设备,可以使用 ExecuTorchExportedProgram 降低到硬件中。要了解有关 AOTInductor (AOTI) 的更多信息,请参考 AOTI 教程。要了解有关 ExecuTorch 的更多信息,请参考 ExecuTorch 教程