评价此页

通过区域编译减少 AoT 冷启动编译时间#

作者: Sayak Paul, Charles Bensimon, Angela Yi

区域编译方案中,我们展示了如何在保持(几乎)全部编译优势的同时,减少冷启动编译时间。该方法是针对即时(JIT)编译进行演示的。

本方案展示了在预先(AoT)编译模型时如何应用类似的原则。如果您不熟悉 AOTInductor 和 torch.export,建议您查看此教程

先决条件#

  • Pytorch 2.6 或更高版本

  • 熟悉区域编译

  • 熟悉 AOTInductor 和 torch.export

设置#

在开始之前,如果尚未安装 torch,我们需要先进行安装。

pip install torch

步骤#

在本方案中,我们将遵循上述区域编译方案中的相同步骤

  1. 导入所有必要的库。

  2. 定义并初始化一个具有重复区域的神经网络。

  3. 测量完整模型与采用 AoT 区域编译的编译时间。

首先,让我们导入加载数据所需的必要库

import torch
torch.set_grad_enabled(False)

from time import perf_counter

定义神经网络#

我们将使用与区域编译方案中相同的神经网络结构。

我们将使用一个由重复层组成的网络。这模拟了一个典型由许多 Transformer 块组成的大型语言模型。在本方案中,我们将使用 nn.Module 类创建一个 Layer 作为重复区域的代理。然后,我们将创建一个由该 Layer 类的 64 个实例组成的 Model

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

预先编译模型 (AoT)#

由于我们是预先编译模型,因此需要准备好我们期望模型在实际部署过程中会见到的代表性输入示例。

让我们创建一个 Model 实例并传入一些示例输入数据。

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
output = model(input)
print(f"{output.shape=}")
output.shape=torch.Size([10, 10])

现在,让我们预先编译我们的模型。我们将使用上面创建的 input 传给 torch.export。这将产生一个我们可以编译的 torch.export.ExportedProgram

/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:320: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(

我们可以从该 path 加载并使用它来执行推理。

compiled_binary = torch._inductor.aoti_load_package(path)
output_compiled = compiled_binary(input)
print(f"{output_compiled.shape=}")
output_compiled.shape=torch.Size([10, 10])

预先编译模型的_区域_#

另一方面,预先编译模型区域需要一些关键的更改。

由于计算模式由模型中所有重复的块(本例中为 Layer 实例)共享,我们只需编译一个单一的块,并让 Inductor 对其进行复用。

model = Model().cuda()
path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={
        # compile artifact w/o saving params in the artifact
        "aot_inductor.package_constants_in_so": False,
    }
)
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)

导出的程序(torch.export.ExportedProgram)包含张量计算、一个包含所有提升参数和缓存张量值的 state_dict,以及其他元数据。我们将 aot_inductor.package_constants_in_so 指定为 False,以便不在生成的工件中序列化模型参数。

现在,在加载编译好的二进制文件时,我们可以复用每个块的现有参数。这使我们能够利用上面获得的编译二进制文件。

for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer

output_regional_compiled = model(input)
print(f"{output_regional_compiled.shape=}")
output_regional_compiled.shape=torch.Size([10, 10])

就像 JIT 区域编译一样,预先编译模型内的区域可以显著减少冷启动时间。具体数值会因模型而异。

尽管完整模型编译提供了最全面的优化范围,但出于实际目的,并且根据模型类型的不同,我们观察到区域编译(无论是 JiT 还是 AoT)都能提供相似的速度优势,同时大幅缩短了冷启动时间。

测量编译时间#

接下来,让我们测量完整模型和区域编译的编译时间。

def measure_compile_time(input, regional=False):
    start = perf_counter()
    model = aot_compile_load_model(regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # make sure the model works.
    _ = model(input)
    return end - start

def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                model.layers[0] if regional else model,
                args=(input,)
            ),
            inductor_configs=inductor_configs,
        )

        if regional:
            for layer in model.layers:
                compiled_layer = torch._inductor.aoti_load_package(path)
                compiled_layer.load_constants(
                    layer.state_dict(), check_full_update=True, user_managed=True
                )
                layer.forward = compiled_layer
        else:
            model = torch._inductor.aoti_load_package(path)
    return model

input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_compile_time(input, regional=False)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_compile_time(input, regional=True)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Full model compilation time = 11.67 seconds
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Regional compilation time = 5.05 seconds

模型中可能存在与编译不兼容的层。因此,完整编译会导致计算图碎片化,从而可能导致延迟下降。在这些情况下,区域编译会更有益处。

结论#

本方案展示了如何在预先编译模型时控制冷启动时间。当您的模型具有重复块时(这在大型生成模型中很常见),这种方法非常有效。我们在各种模型上使用了该方案以提升实时性能。在此了解更多

脚本运行总时长:(0 分 43.378 秒)