评价此页

使用区域编译减少 AoT 冷启动编译时间#

作者: Sayak Paul, Charles Bensimon, Angela Yi

区域编译教程 中,我们展示了如何在保留(几乎)全部编译优势的同时减少冷启动编译时间。这已针对即时 (JIT) 编译进行了演示。

本教程展示了如何在提前 (AoT) 编译模型时应用类似的原理。如果您不熟悉 AOTInductor 和 torch.export,我们建议您查看 本教程

先决条件#

  • Pytorch 2.6 或更高版本

  • 熟悉区域编译

  • 熟悉 AOTInductor 和 torch.export

设置#

在开始之前,我们需要安装 torch,如果尚未安装。

pip install torch

步骤#

在本教程中,我们将遵循上述区域编译教程中的相同步骤

  1. 导入所有必要的库。

  2. 定义并初始化具有重复区域的神经网络。

  3. 测量完整模型和使用 AoT 进行区域编译的编译时间。

首先,让我们导入加载数据所需的库

import torch
torch.set_grad_enabled(False)

from time import perf_counter

定义神经网络#

我们将使用与区域编译教程相同的神经网络结构。

我们将使用一个由重复层组成的网络。这模拟了一个大型语言模型,它通常由许多 Transformer 块组成。在本教程中,我们将使用 nn.Module 类创建一个 Layer 作为重复区域的代理。然后,我们将创建一个由 64 个此类 Layer 实例组成的 Model

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

提前编译模型#

由于我们是提前编译模型,因此需要准备代表性的输入示例,我们期望模型在实际部署期间会看到这些示例。

让我们创建一个 Model 实例,并为其提供一些样本输入数据。

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
output = model(input)
print(f"{output.shape=}")
output.shape=torch.Size([10, 10])

现在,让我们提前编译我们的模型。我们将使用上面创建的 input 传递给 torch.export。这将产生一个 torch.export.ExportedProgram,我们可以对其进行编译。

/usr/local/lib/python3.10/dist-packages/torch/backends/cuda/__init__.py:131: UserWarning:

Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.ac.cn/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)

/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:312: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

我们可以从该 path 加载并使用它来执行推理。

compiled_binary = torch._inductor.aoti_load_package(path)
output_compiled = compiled_binary(input)
print(f"{output_compiled.shape=}")
output_compiled.shape=torch.Size([10, 10])

提前编译模型的 _区域_#

另一方面,提前编译模型区域需要一些关键的更改。

由于计算模式被模型中所有重复的块(在本例中为 Layer 实例)共享,因此我们可以仅编译一个块,然后让 inductor 重用它。

model = Model().cuda()
path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={
        # compile artifact w/o saving params in the artifact
        "aot_inductor.package_constants_in_so": False,
    }
)

导出的程序(torch.export.ExportedProgram)包含张量计算、一个 state_dict,其中包含所有提升的参数和缓冲区的张量值以及其他元数据。我们将 aot_inductor.package_constants_in_so 设置为 False,以避免在生成的工件中序列化模型参数。

现在,在加载编译后的二进制文件时,我们可以重用每个块的现有参数。这使我们能够利用上面获得的编译后的二进制文件。

for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer

output_regional_compiled = model(input)
print(f"{output_regional_compiled.shape=}")
output_regional_compiled.shape=torch.Size([10, 10])

与 JIT 区域编译一样,在模型内部提前编译区域可以显著减少冷启动时间。实际数字会因模型而异。

尽管完整模型编译提供了最广泛的优化范围,但出于实际目的,并且取决于模型类型,我们已经看到区域编译(JIT 和 AoT)提供了相似的速度优势,同时极大地减少了冷启动时间。

测量编译时间#

接下来,让我们测量完整模型和区域编译的编译时间。

def measure_compile_time(input, regional=False):
    start = perf_counter()
    model = aot_compile_load_model(regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # make sure the model works.
    _ = model(input)
    return end - start

def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                model.layers[0] if regional else model,
                args=(input,)
            ),
            inductor_configs=inductor_configs,
        )

        if regional:
            for layer in model.layers:
                compiled_layer = torch._inductor.aoti_load_package(path)
                compiled_layer.load_constants(
                    layer.state_dict(), check_full_update=True, user_managed=True
                )
                layer.forward = compiled_layer
        else:
            model = torch._inductor.aoti_load_package(path)
    return model

input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_compile_time(input, regional=False)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_compile_time(input, regional=True)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
Full model compilation time = 11.46 seconds
Regional compilation time = 4.76 seconds

模型中也可能存在与编译不兼容的层。因此,完整编译将导致计算图碎片化,从而可能导致延迟下降。在这种情况下,区域编译可能是有益的。

结论#

本教程展示了如何在提前编译模型时控制冷启动时间。当模型具有重复块时,这会变得有效,这在大型生成模型中通常会看到。我们在各种模型上使用了此教程来加速实时性能。在此处了解更多信息:here

脚本总运行时间: (0 分 42.500 秒)