• 文档 >
  • 通过 Inductor 使用 Intel GPU 后端进行 PyTorch 2 导出量化
快捷方式

PyTorch 2 通过 Inductor 使用 Intel GPU 后端进行导出量化

作者: 颜志伟, 王艺康, 张良刚, 刘河, 崔逸峰

先决条件

介绍

本教程介绍 XPUInductorQuantizer,它旨在为 Intel GPU 上的推理提供量化模型。XPUInductorQuantizer 使用 PyTorch Export Quantization 流程并将量化模型降低到 Inductor 中。

PyTorch 2 Export Quantization 流程使用 torch.export 将模型捕获到图中并在 ATen 图之上执行量化转换。这种方法预计将显著提高模型覆盖率,同时具有更好的可编程性和简化的用户体验。TorchInductor 是一个编译器后端,它将 TorchDynamo 生成的 FX 图转换为优化的 C++/Triton 内核。

量化流程分为三个步骤

  • 步骤 1:基于 torch export 机制从 Eager 模型捕获 FX 图。

  • 步骤 2:基于捕获的 FX 图应用量化流程,包括定义后端特定量化器、生成带观测器的准备模型、执行准备模型的校准以及将准备模型转换为量化模型。

  • 步骤 3:使用 API torch.compile 将量化模型降低到 Inductor 中,这将调用 Triton 内核或 oneDNN GEMM/卷积内核。

这个流程的高级架构可能如下所示

float_model(Python)                          Example Input
    \                                              /
     \                                            /
—--------------------------------------------------------
|                         export                       |
—--------------------------------------------------------
                            |
                    FX Graph in ATen
                            |            X86InductorQuantizer
                            |                 /
—--------------------------------------------------------
|                      prepare_pt2e                     |
|                           |                           |
|                     Calibrate/Train                   |
|                           |                           |
|                      convert_pt2e                     |
—--------------------------------------------------------
                            |
                     Quantized Model
                            |
—--------------------------------------------------------
|                    Lower into Inductor                |
—--------------------------------------------------------
                            |
       OneDNN kernels                Triton Kernels

训练后量化

我们目前只支持静态量化。

建议通过 Intel GPU 渠道安装以下依赖项

pip3 install torch torchvision torchaudio pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu

请注意,由于 Inductor freeze 功能尚未默认开启,您必须使用 TORCHINDUCTOR_FREEZING=1 运行您的示例代码。

例如

TORCHINDUCTOR_FREEZING=1 python xpu_inductor_quantizer_example.py

1. 捕获 FX 图

我们将首先执行必要的导入,从即时模式(eager)模块中捕获 FX 图。

import torch
import torchvision.models as models
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.xpu_inductor_quantizer as xpuiq
from torchao.quantization.pt2e.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
from torch.export import export

# Create the Eager Model
model_name = "resnet18"
model = models.__dict__[model_name](weights=models.ResNet18_Weights.DEFAULT)

# Set the model to eval mode
model = model.eval().to("xpu")

# Create the data, using the dummy data here as an example
traced_bs = 50
x = torch.randn(traced_bs, 3, 224, 224, device="xpu").contiguous(memory_format=torch.channels_last)
example_inputs = (x,)

# Capture the FX Graph to be quantized
with torch.no_grad():
    exported_model = export(
        model,
        example_inputs,
    ).module()

接下来,我们将量化 FX 模块。

2. 应用量化

捕获 FX 模块后,我们将导入 Intel GPU 的后端量化器并对其进行配置以量化模型。

quantizer = XPUInductorQuantizer()
quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())

XPUInductorQuantizer 中的默认量化配置对激活和权重都使用有符号 8 位。张量是按张量量化的,而权重是有符号 8 位按通道量化的。

此外,除了使用非对称量化激活的默认量化配置外,还支持有符号 8 位对称量化激活,这有可能提供更好的性能。

from torchao.quantization.pt2e.observer import HistogramObserver, PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec
from torchao.quantization.pt2e.quantizer import QuantizationConfig
from typing import Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
    from torchao.quantization.pt2e import ObserverOrFakeQuantizeConstructor
def get_xpu_inductor_symm_quantization_config():
    extra_args: dict[str, Any] = {"eps": 2**-12}
    act_observer_or_fake_quant_ctr = HistogramObserver
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,  # Change the activation quant config to symmetric
        is_dynamic=False,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_channel_symmetric, # Same as the default config, the only supported option for weight
        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    bias_quantization_spec = None  # will use placeholder observer by default
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
        False,
    )
    return quantization_config

# Then, set the quantization configuration to the quantizer.
quantizer = XPUInductorQuantizer()
quantizer.set_global(get_xpu_inductor_symm_quantization_config())

导入后端特定量化器后,准备模型以进行训练后量化。prepare_pt2eBatchNorm 运算符折叠到前置的 Conv2d 运算符中,并将观测器插入到模型中的适当位置。

prepared_model = prepare_pt2e(exported_model, quantizer)

(仅适用于静态量化)在将观测器插入模型后,校准 prepared_model

# We use the dummy data as an example here
prepared_model(*example_inputs)

# Alternatively: user can define the dataset to calibrate
# def calibrate(model, data_loader):
#     model.eval()
#     with torch.no_grad():
#         for image, target in data_loader:
#             model(image)
# calibrate(prepared_model, data_loader_test)  # run calibration on sample data

最后,将校准后的模型转换为量化模型。convert_pt2e 接受一个校准后的模型并生成一个量化模型。

converted_model = convert_pt2e(prepared_model)

完成这些步骤后,量化流程就完成了,量化模型也可用了。

3. 降低到 Inductor 中

然后,量化模型将被降低到 Inductor 后端。

with torch.no_grad():
    optimized_model = torch.compile(converted_model)

    # Running some benchmark
    optimized_model(*example_inputs)

在更高级的场景中,int8-混合-bf16 量化开始发挥作用。在这种情况下,卷积或 GEMM 运算符在没有后续量化节点的情况下以 BFloat16 而不是 Float32 生成输出。随后,BFloat16 张量无缝地通过后续的点式运算符传播,有效地最小化内存使用并可能提高性能。此功能的使用与常规 BFloat16 Autocast 的使用类似,就像将脚本包装在 BFloat16 Autocast 上下文内一样简单。

with torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16), torch.no_grad():
        # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into indcutor backend,
        # For operators such as QConvolution and QLinear:
        # * The input data type is consistently defined as int8, attributable to the presence of a pair
        #    of quantization and dequantization nodes inserted at the input.
        # * The computation precision remains at int8.
        # * The output data type may vary, being either int8 or BFloat16, contingent on the presence
        #   of a pair of quantization and dequantization nodes at the output.
        # For non-quantizable pointwise operators, the data type will be inherited from the previous node,
        # potentially resulting in a data type of BFloat16 in this scenario.
        # For quantizable pointwise operators such as QMaxpool2D, it continues to operate with the int8
        # data type for both input and output.
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)

结论

在本教程中,我们学习了如何利用 XPUInductorQuantizer 对模型执行训练后量化,以在 Intel GPU 上进行推理,并利用 PyTorch 2 的 Export Quantization 流程。我们涵盖了捕获 FX 图、应用量化以及使用 torch.compile 将量化模型降低到 Inductor 后端的逐步过程。此外,我们还探讨了使用 int8-混合-bf16 量化的好处,以提高内存效率和潜在的性能增益,尤其是在使用 BFloat16 自动混合精度时。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源