torch.export 流、常见挑战及解决方案演示#
作者: Ankith Gunapal, Jordi Ramon, Marcos Carranza
在 torch.export 入门教程 中,我们学习了如何使用 torch.export。本教程在前一教程的基础上进行了扩展,通过代码演示了导出流行模型的流程,并解决了使用 torch.export 时可能遇到的常见挑战。
在本教程中,您将学习如何针对以下用例导出模型
视频分类器(MViT)
自动语音识别(OpenAI Whisper-Tiny)
图像字幕生成(BLIP)
可提示图像分割(SAM2)
选择这四种模型是为了演示 torch.export 的独特功能,以及在实现过程中遇到的一些实际考虑和问题。
先决条件#
PyTorch 2.4 或更高版本
对
torch.export和 PyTorch Eager 推理有基本了解。
torch.export 的关键要求:无图中断#
torch.compile 通过使用 JIT 将 PyTorch 代码编译为优化内核来加速 PyTorch 代码。它使用 TorchDynamo 优化给定模型,并创建一个优化的图,然后使用 API 中指定的后端将其降低到硬件。当 TorchDynamo 遇到不支持的 Python 功能时,它会中断计算图,让默认的 Python 解释器处理不支持的代码,然后恢复捕获图。计算图中的这种中断称为图中断。
torch.export 和 torch.compile 的一个关键区别是,torch.export 不支持图中断,这意味着您要导出的整个模型或模型的一部分需要是一个单一的图。这是因为处理图中断涉及使用默认的 Python 评估来解释不支持的操作,这与 torch.export 的设计目的不兼容。您可以在此 链接 中阅读有关各种 PyTorch 框架之间差异的详细信息。
您可以使用以下命令识别程序中的图中断
TORCH_LOGS="graph_breaks" python <file_name>.py
您需要修改程序以消除图中断。一旦解决,您就可以导出模型了。PyTorch 对流行的 HuggingFace 和 TIMM 模型运行 torch.compile 的夜间基准测试。其中大多数模型都没有图中断。
此配方中的模型没有图中断,但会因 torch.export 而失败。
视频分类#
MViT 是一类基于 MultiScale Vision Transformers 的模型。该模型已使用 Kinetics-400 数据集 针对视频分类进行了训练。该模型及其相关数据集可用于游戏场景中的动作识别。
下面的代码通过使用 batch_size=2 进行追踪来导出 MViT,然后检查导出的程序是否可以使用 batch_size=4 运行。
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
# Export the model.
exported_program = torch.export.export(
model,
(input_frames,),
)
# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
exported_program.module()(input_frames)
except Exception:
tb.print_exc()
错误:静态批次大小#
raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4
默认情况下,导出流程会假设所有输入形状都是静态的,因此如果您使用与追踪时使用的输入形状不同的输入形状运行程序,您将遇到错误。
解决方案#
为了解决此错误,我们指定输入的第一维(batch_size)为动态,指定了预期的 batch_size 范围。在下面显示的更正示例中,我们指定预期的 batch_size 范围可以是从 1 到 16。一个需要注意的细节是 min=2 并非错误,这一点在 0/1 特殊化问题 中得到了解释。有关 torch.export 动态形状的详细描述,请参阅导出教程。下面提供的代码演示了如何导出具有动态批次大小的 mViT。
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
model,
(input_frames,),
# Specify the first dimension of the input x as dynamic
dynamic_shapes={"x": {0: batch_dim}},
)
# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
exported_program.module()(input_frames)
except Exception:
tb.print_exc()
自动语音识别#
自动语音识别 (ASR) 是利用机器学习将口语转换为文本。 Whisper 是来自 OpenAI 的基于 Transformer 的编码器-解码器模型,它在 680,000 小时的 ASR 和语音翻译标记数据上进行了训练。下面的代码尝试导出用于 ASR 的 whisper-tiny 模型。
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id
model.eval()
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))
错误:TorchDynamo 严格追踪#
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'
默认情况下,torch.export 使用 TorchDynamo(一个字节码分析引擎)追踪您的代码,该引擎会符号化地分析您的代码并构建图。这种分析提供了更强的安全性保证,但并非所有 Python 代码都支持。当我们使用默认的严格模式导出 whisper-tiny 模型时,由于存在不支持的功能,它通常会在 Dynamo 中返回错误。要理解为什么这会在 Dynamo 中导致错误,您可以参考此 GitHub issue。
解决方案#
为了解决上述错误,torch.export 支持 non_strict 模式,在该模式下,程序使用 Python 解释器进行追踪,这与 PyTorch Eager 执行类似。唯一的区别是所有 Tensor 对象都将被 ProxyTensors 替换,后者会将它们的所有操作记录到图中。通过使用 strict=False,我们可以成功导出程序。
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id
model.eval()
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)
图像字幕生成#
图像字幕生成 是用文字定义图像内容的任务。在游戏场景中,图像字幕生成可以通过动态生成场景中各种游戏对象的文本描述来增强游戏体验,从而为玩家提供更多细节。 BLIP 是由 Salesforce Research 发布的、流行的图像字幕生成模型。下面的代码尝试使用 batch_size=1 导出 BLIP。
import torch
from models.blip import blip_decoder
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)
错误:无法修改具有冻结存储的张量#
在导出模型时,它可能会失败,因为模型实现可能包含某些 Python 操作,而这些操作尚未得到 torch.export 的支持。其中一些失败可能有解决方法。BLIP 就是一个例子,原始模型会失败,但可以通过对代码进行少量更改来解决。 torch.export 在 ExportDB 中列出了支持和不支持操作的常见情况,并展示了如何修改代码以使其兼容导出。
File "/BLIP/models/blip.py", line 112, in forward
text.input_ids[:,0] = self.tokenizer.bos_token_id
File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage
解决方案#
克隆导出失败的张量。
text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id
注意
此限制已在 PyTorch 2.7 夜间版本中放宽。这应该可以直接在 PyTorch 2.7 中正常工作。
可提示图像分割#
图像分割 是一种计算机视觉技术,它根据像素的特征将数字图像划分为不同的组,即段。 Segment Anything Model (SAM) 引入了可提示图像分割,它可以根据指示所需对象的提示来预测对象掩码。 SAM 2 是首个用于跨图像和视频分割对象的统一模型。 SAM2ImagePredictor 类提供了模型用于提示模型的简单接口。该模型可以接受点提示和框提示以及前一迭代预测的掩码作为输入。由于 SAM2 对对象跟踪提供了强大的零样本性能,因此可用于跟踪场景中的游戏对象。
在 SAM2ImagePredictor 的 predict 方法中的张量操作发生在 _predict 方法中。因此,我们尝试像这样导出。
ep = torch.export.export(
self._predict,
args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
kwargs={"return_logits": return_logits},
strict=False,
)
错误:模型不是 torch.nn.Module 类型#
torch.export 需要模块是 torch.nn.Module 类型。但是,我们尝试导出的模块是一个类方法。因此会报错。
Traceback (most recent call last):
File "/sam2/image_predict.py", line 20, in <module>
masks, scores, _ = predictor.predict(
File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
ep = torch.export.export(
File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.
解决方案#
我们编写了一个继承自 torch.nn.Module 的辅助类,并在该类的 forward 方法中调用 _predict method。完整的代码可以在 这里 找到。
class ExportHelper(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(_, *args, **kwargs):
return self._predict(*args, **kwargs)
model_to_export = ExportHelper()
ep = torch.export.export(
model_to_export,
args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
kwargs={"return_logits": return_logits},
strict=False,
)
结论#
在本教程中,我们学习了如何通过正确的配置和简单的代码修改来解决挑战,从而使用 torch.export 导出流行用例的模型。一旦您能够导出模型,就可以将 ExportedProgram 降低到服务器端的 AOTInductor 或边缘设备端的 ExecuTorch。要了解有关 AOTInductor (AOTI) 的更多信息,请参阅 AOTI 教程。要了解有关 ExecuTorch 的更多信息,请参阅 ExecuTorch 教程。