
Exporting a Custom LLM

If you have your own PyTorch model that is an LLM, this guide shows how to manually export and lower it to ExecuTorch, applying many of the same optimizations covered in the earlier export_llm guide.

This example uses Karpathy's nanoGPT, a minimal implementation of GPT-2 124M. The guide applies to other language models as well, since ExecuTorch is model-agnostic.

Exporting to ExecuTorch (Basic)

Exporting converts the PyTorch model into a format that can run efficiently on consumer devices.

For this example, you will need the nanoGPT model and the corresponding tokenizer vocabulary.

curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O

Or, with wget:

wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json

Converting the model to a format optimized for standalone execution involves two steps. First, use PyTorch's export function to convert the PyTorch model into an intermediate, platform-independent representation. Then use ExecuTorch's to_edge and to_executorch methods to prepare the model for on-device execution. This creates a .pte file that can be loaded at runtime by a desktop or mobile application.

Create a file called export_nanogpt.py with the following contents:

# export_nanogpt.py

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export, export_for_training

from model import GPT

# Load the model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model,  compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

To export, run the script with python export_nanogpt.py (or python3, depending on your environment). It will generate a nanogpt.pte file in the current directory.
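As a quick sanity check, you can try loading the generated .pte file with the ExecuTorch Python bindings before integrating it into an application. The snippet below is a minimal sketch, assuming the pybindings module is available in your ExecuTorch installation and that the model takes a single token-id tensor as input, as in the export script above.

# check_nanogpt.py (optional sanity check; assumes the ExecuTorch Python bindings are installed)
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the exported ExecuTorch program.
et_module = _load_for_executorch("nanogpt.pte")

# Run a single forward pass with dummy token ids. The sequence length must
# stay within the dynamic-shape bounds declared at export time.
tokens = torch.randint(0, 100, (1, 64), dtype=torch.long)
outputs = et_module.forward([tokens])
print(outputs[0].shape)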

For more information, see Exporting to ExecuTorch and torch.export.

Backend Delegation

While ExecuTorch provides portable, cross-platform implementations of all operators, it also offers specialized backends for a number of different targets. These include, but are not limited to: x86 and ARM CPU acceleration via the XNNPACK backend, Apple device acceleration via the Core ML backend and the Metal Performance Shaders (MPS) backend, and GPU acceleration via the Vulkan backend.

Because the optimizations are specific to a given backend, each .pte file is tied to the backend targeted at export time. To support multiple devices, such as XNNPACK acceleration for Android and Core ML for iOS, export a separate PTE file for each backend.
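If you need to produce several such files, a small loop over backend partitioners keeps the exports consistent. The sketch below is illustrative, reusing `traced_model` from the export script and the to_edge_transform_and_lower() API described next; only the XNNPACK partitioner is imported here, and the commented-out Core ML entry is a placeholder for whichever other backend you target.

# Sketch: produce one .pte per backend from the same traced model.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

backends = {
    "xnnpack": XnnpackPartitioner(),
    # "coreml": CoreMLPartitioner(),  # placeholder: import from the Core ML backend when targeting iOS
}

for name, partitioner in backends.items():
    # Note: each backend may also require its own edge compile config (see the XNNPACK example below).
    edge_manager = to_edge_transform_and_lower(traced_model, partitioner=[partitioner])
    et_program = edge_manager.to_executorch()
    with open(f"nanogpt_{name}.pte", "wb") as file:
        file.write(et_program.buffer)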

To delegate a model to a specific backend during export, ExecuTorch uses the to_edge_transform_and_lower() function. This function takes the exported program from torch.export along with backend-specific partitioner objects. The partitioner identifies the parts of the computation graph that can be optimized by the target backend. Within to_edge_transform_and_lower(), the exported program is converted into an edge dialect program, and the partitioner then delegates the compatible parts of the graph to the backend for acceleration and optimization. Any part of the graph that is not delegated is executed by ExecuTorch's default operator implementations.

To delegate the exported model to a specific backend, we first need to import its partitioner and the corresponding edge compile config from the ExecuTorch codebase, then call to_edge_transform_and_lower.

Here is an example of how to delegate nanoGPT to XNNPACK (for example, if you are deploying to an Android phone):

# export_nanogpt.py

# Load partitioner for Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Model to be delegated to specific backend should use specific edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export_for_training

from model import GPT

# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
    torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
)

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be further lowered to the XNNPACK backend, `traced_model` needs an
# XNNPACK-specific edge compile config.
edge_config = get_xnnpack_edge_compile_config()
# Convert to an edge program and delegate it to the XNNPACK backend by
# invoking `to_edge_transform_and_lower` with the XNNPACK partitioner.
edge_manager = to_edge_transform_and_lower(traced_model, partitioner=[XnnpackPartitioner()], compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the Xnnpack-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

Quantization

Quantization refers to a set of techniques for running computations and storing tensors using lower-precision types. Compared to 32-bit floating point, using 8-bit integers can deliver both a significant speedup and a reduction in memory usage. There are many ways to quantize a model, differing in the amount of pre-processing required, the data types used, and the impact on model accuracy and performance.

Because compute and memory are highly constrained on mobile devices, some form of quantization is necessary to ship large models on consumer electronics. In particular, large language models such as Llama 2 may require quantizing model weights to 4 bits or less.

Leveraging quantization requires transforming the model before export. PyTorch provides the pt2e (PyTorch 2 Export) API for this purpose. This example targets the XNNPACK delegate for CPU acceleration and therefore needs to use the XNNPACK-specific quantizer. Targeting a different backend would require its corresponding quantizer.

To use 8-bit integer dynamic quantization with the XNNPACK delegate, call prepare_pt2e, calibrate the model by running it with representative inputs, and then call convert_pt2e. This updates the computation graph to use the quantized operators where available.

# export_nanogpt.py

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = export_for_training(model, example_inputs).module()

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)

# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)

traced_model = export(m, example_inputs)

Additionally, add or update the to_edge_transform_and_lower() call to use XnnpackPartitioner. This instructs ExecuTorch to optimize the model for CPU execution via the XNNPACK backend.

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
edge_config = get_xnnpack_edge_compile_config()
# Convert to edge dialect and lower to XNNPACK.
edge_manager = to_edge_transform_and_lower(traced_model, partitioner=[XnnpackPartitioner()], compile_config=edge_config)
et_program = edge_manager.to_executorch()

with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

For more information, see Quantization in ExecuTorch.

Profiling and Debugging

After lowering the model by calling to_edge_transform_and_lower(), you may want to see what was delegated and what was not. ExecuTorch provides utility methods to give insight into the delegation. You can use this information to understand the underlying computation and diagnose potential performance issues. Model authors can also use it to structure the model in a way that is compatible with the target backend.

Visualizing the Delegation

The get_delegation_info() method provides a summary of the model's state after the to_edge_transform_and_lower() call:

from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

# ... After call to to_edge_transform_and_lower(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

For nanoGPT targeting the XNNPACK backend, you might see something like the following (note that the numbers below are for illustration only and actual values may vary):

Total delegated subgraphs: 145
Number of delegated nodes: 350
Number of non-delegated nodes: 760

      op_type                    # in_delegated_graphs    # in_non_delegated_graphs
 0    aten__softmax_default      12                       0
 1    aten_add_tensor            37                       0
 2    aten_addmm_default         48                       0
 3    aten_any_dim               0                        12
      ...                        ...                      ...
25    aten_view_copy_default     96                       122
      ...                        ...                      ...
30    Total                      350                      760

From the table, you can see that the operator aten_view_copy_default appears 96 times in the delegated graphs and 122 times in the non-delegated graphs. For a more detailed view, use the format_delegated_graph() method to get a formatted string printout of the whole graph, or use print_delegated_graph() to print it directly.

from executorch.exir.backend.utils import format_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(format_delegated_graph(graph_module))

For large models, this can produce a lot of output. Consider using Control+F or Command+F to search for the operator you are interested in (e.g. "aten_view_copy_default"), and observe which instances are not under lowered graphs.
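If you would rather count these occurrences programmatically instead of searching by hand, a short sketch like the following can help. It simply counts substring matches in the formatted graph string, so treat the result as a rough guide rather than a precise delegation report.

# Rough sketch: count how often an operator name appears in the printed graph.
from executorch.exir.backend.utils import format_delegated_graph

graph_str = format_delegated_graph(edge_manager.exported_program().graph_module)
op_name = "aten_view_copy_default"
print(f"{op_name} appears {graph_str.count(op_name)} times in the formatted graph")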

In the following snippet of nanoGPT output, you can see that the transformer module has been delegated to XNNPACK, while the where operator has not.

%aten_where_self_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default_33, %scalar_tensor_23, %scalar_tensor_22), kwargs = {})
%lowered_module_144 : [num_users=1] = get_attr[target=lowered_module_144]
backend_id: XnnpackBackend
lowered graph():
    %p_transformer_h_0_attn_c_attn_weight : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_weight]
    %p_transformer_h_0_attn_c_attn_bias : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_bias]
    %getitem : [num_users=1] = placeholder[target=getitem]
    %sym_size : [num_users=2] = placeholder[target=sym_size]
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [%sym_size, 768]), kwargs = {})
    %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_transformer_h_0_attn_c_attn_weight, [1, 0]), kwargs = {})
    %aten_addmm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.addmm.default](args = (%p_transformer_h_0_attn_c_attn_bias, %aten_view_copy_default, %aten_permute_copy_default), kwargs = {})
    %aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_addmm_default, [1, %sym_size, 2304]), kwargs = {})
    return [aten_view_copy_default_1]

Further Model Analysis and Debugging

Through the ExecuTorch Developer Tools, users can profile model execution to obtain timing information for each operator in the model, debug model numerics, and more.

An ETRecord is an artifact generated at export time that contains the model graph and source-level metadata linking the ExecuTorch program to the original PyTorch model. You can view all profiling events without an ETRecord, but with an ETRecord you can additionally link each event to the operator type being executed, the module hierarchy, and the stack trace of the original PyTorch source code. For more information, see the ETRecord documentation.

In your export script, after calling to_edge() and to_executorch(), call generate_etrecord(), passing in the EdgeProgramManager from to_edge() and the ExecuTorchProgramManager from to_executorch(). Make sure to copy the EdgeProgramManager, as the call to to_edge_transform_and_lower() modifies the graph in place.

# export_nanogpt.py

import copy
from executorch.devtools import generate_etrecord

# Make the deep copy immediately after the call to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)

# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)

Run the export script, and the ETRecord will be generated as etrecord.bin.
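Once you also collect an ETDump from a profiled run of the model (not covered in this guide), the Developer Tools' Inspector can correlate it with the ETRecord to show per-operator timing. The snippet below is a sketch that assumes an ETDump file named etdump.etdp has already been produced; see the Developer Tools documentation for how to generate one.

# Sketch: inspect profiling data using the ETRecord and a separately collected ETDump.
from executorch.devtools import Inspector

inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")
# Print profiling events, including per-operator timing, as a table.
inspector.print_data_tabular()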

To learn more about the ExecuTorch Developer Tools, see the Introduction to the ExecuTorch Developer Tools.
