评价此页

量化#

创建于:2019 年 10 月 9 日 | 最后更新于:2025 年 6 月 17 日

警告

量化功能处于 Beta 阶段,可能会发生变化。

量化简介#

量化是指以低于浮点精度的位宽执行计算和存储张量的技术。量化模型以降低的精度(而非全精度(浮点)值)对张量执行部分或全部操作。这允许更紧凑的模型表示,并在许多硬件平台上使用高性能矢量化操作。与典型的 FP32 模型相比,PyTorch 支持 INT8 量化,可将模型大小减少 4 倍,内存带宽需求减少 4 倍。INT8 计算的硬件支持通常比 FP32 计算快 2 到 4 倍。量化主要是一种加速推理的技术,量化算子仅支持正向传播。

PyTorch 支持多种量化深度学习模型的方法。在大多数情况下,模型在 FP32 中训练,然后转换为 INT8。此外,PyTorch 还支持量化感知训练,它使用伪量化模块在正向和反向传播中模拟量化误差。请注意,所有计算都在浮点中进行。在量化感知训练结束时,PyTorch 提供转换函数,将训练后的模型转换为较低精度。

在较低级别,PyTorch 提供了一种表示量化张量并对其执行操作的方法。它们可用于直接构建全部或部分计算以较低精度执行的模型。提供了更高级别的 API,其中包含将 FP32 模型转换为较低精度的典型工作流,且精度损失最小。

量化 API 摘要#

PyTorch 提供三种不同的量化模式:即时模式量化(Eager Mode Quantization)、FX 图模式量化(FX Graph Mode Quantization)(维护中)和 PyTorch 2 导出量化(PyTorch 2 Export Quantization)。

即时模式量化是一个 Beta 功能。用户需要手动进行融合并指定量化和反量化的发生位置,而且它只支持模块而不支持函数。

FX 图模式量化是 PyTorch 中自动化的量化工作流程,目前它是一个原型功能,由于有了 PyTorch 2 导出量化,它目前处于维护模式。它通过增加对函数(functional)的支持并自动化量化过程来改进即时模式量化,尽管人们可能需要重构模型以使其与 FX 图模式量化兼容(可通过 torch.fx 进行符号跟踪)。请注意,FX 图模式量化预计无法在任意模型上工作,因为模型可能无法进行符号跟踪。我们将把它集成到 torchvision 等领域库中,用户将能够使用 FX 图模式量化对类似支持领域库中的模型进行量化。对于任意模型,我们将提供通用指导,但要实际使其工作,用户可能需要熟悉 torch.fx,尤其是如何使模型可符号跟踪。

PyTorch 2 Export Quantization 是新的全图模式量化工作流,在 PyTorch 2.1 中作为原型功能发布。PyTorch 2 正在转向更好的全程序捕获解决方案(torch.export),因为它与 FX Graph Mode Quantization 中使用的程序捕获解决方案 torch.fx.symbolic_trace(在 1.4 万个模型中捕获率为 72.7%)相比,可以捕获更高比例(在 1.4 万个模型中捕获率为 88.8%)的模型。torch.export 仍对某些 Python 结构存在限制,并且需要用户参与才能支持导出模型中的动态性,但总体而言,它比之前的程序捕获解决方案有所改进。PyTorch 2 Export Quantization 是为 torch.export 捕获的模型而构建的,同时兼顾了建模用户和后端开发人员的灵活性和生产力。主要功能包括:(1) 可编程 API,用于配置模型量化方式,可扩展到更多用例;(2) 简化建模用户和后端开发人员的用户体验,因为他们只需与单个对象(Quantizer)交互,即可表达用户量化模型的意图以及后端支持;(3) 可选的参考量化模型表示,可以使用整数运算表示量化计算,更接近于硬件中实际发生的量化计算。

建议量化的新用户首先尝试 PyTorch 2 Export Quantization,如果效果不佳,可以尝试 Eager Mode Quantization。

下表比较了 Eager Mode Quantization、FX Graph Mode Quantization 和 PyTorch 2 Export Quantization 之间的差异。

Eager Mode Quantization

FX Graph Mode Quantization

PyTorch 2 Export Quantization

发布状态

测试版

原型(维护)

原型

运算符融合

手动

自动

自动

量化/反量化放置

手动

自动

自动

量化模块

支持

支持

支持

量化函数/Torch 运算符

手动

自动

支持

支持自定义

有限支持

完全支持

完全支持

量化模式支持

训练后量化:静态、动态、仅权重

量化感知训练:静态

训练后量化:静态、动态、仅权重

量化感知训练:静态

由后端特定量化器定义

输入/输出模型类型

torch.nn.Module

torch.nn.Module (可能需要一些重构才能使模型与 FX Graph Mode Quantization 兼容)

torch.fx.GraphModule (由 torch.export 捕获)

支持三种量化类型

  1. 动态量化(权重被量化,激活值以浮点形式读取/存储,并为计算进行量化)

  2. 静态量化(权重被量化,激活值被量化,训练后需要校准)

  3. 静态量化感知训练(权重被量化,激活值被量化,训练期间对量化数值进行建模)

请参阅我们的《PyTorch 量化简介》博客文章,以更全面地了解这些量化类型之间的权衡。

动态量化和静态量化之间的运算符覆盖范围有所不同,具体如下表所示。

静态量化

动态量化

nn.Linear
nn.Conv1d/2d/3d
Y
Y
Y
N
nn.LSTM

nn.GRU
Y(通过
自定义模块)
N
Y

Y
nn.RNNCell
nn.GRUCell
nn.LSTMCell
N
N
N
Y
Y
Y

nn.EmbeddingBag

Y(激活值为 fp32)

Y

nn.Embedding

Y

Y

nn.MultiheadAttention

Y(通过自定义模块)

不支持

激活

广泛支持

不变,计算保持在 fp32

Eager 模式量化#

有关量化流程的通用介绍,包括不同类型的量化,请参阅通用量化流程

训练后动态量化#

这是最简单的量化形式,其中权重提前量化,但激活在推理过程中动态量化。这适用于模型执行时间主要由从内存加载权重而不是计算矩阵乘法支配的情况。对于小批量大小的 LSTM 和 Transformer 类型模型,情况确实如此。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                 /
linear_weight_fp32

# dynamically quantized model
# linear and LSTM weights are in int8
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                     /
   linear_weight_int8

PTDQ API 示例

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

要了解有关动态量化的更多信息,请参阅我们的动态量化教程

训练后静态量化#

训练后静态量化 (PTQ static) 对模型的权重和激活进行量化。它尽可能将激活融合到前置层中。它需要使用代表性数据集进行校准,以确定激活的最佳量化参数。训练后静态量化通常在内存带宽和计算节省都很重要时使用,其中 CNN 是典型用例。

在应用训练后静态量化之前,我们可能需要修改模型。请参阅Eager 模式静态量化的模型准备

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                    /
    linear_weight_fp32

# statically quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                    /
  linear_weight_int8

PTSQ API 示例

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

要了解有关静态量化的更多信息,请参阅静态量化教程

静态量化的量化感知训练#

量化感知训练 (QAT) 在训练期间对量化的影响进行建模,从而比其他量化方法获得更高的精度。我们可以对静态、动态或仅权重量化进行 QAT。在训练期间,所有计算都以浮点数完成,fake_quant 模块通过钳制和舍入来模拟量化效果,以模拟 INT8 的效果。模型转换后,权重和激活被量化,并且激活尽可能融合到前置层中。它通常与 CNN 结合使用,与静态量化相比可产生更高的精度。

在应用训练后静态量化之前,我们可能需要修改模型。请参阅Eager 模式静态量化的模型准备

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq

# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

QAT API 示例

import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
    [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model needs to be set to train for QAT logic to work
# the model that will observe weight and activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

要了解有关量化感知训练的更多信息,请参阅QAT 教程

Eager 模式静态量化的模型准备#

目前有必要在 Eager 模式量化之前对模型定义进行一些修改。这是因为当前量化是基于模块进行的。具体来说,对于所有量化技术,用户需要:

  1. 将所有需要输出重新量化(因此具有额外参数)的操作从函数形式转换为模块形式(例如,使用 torch.nn.ReLU 而不是 torch.nn.functional.relu)。

  2. 通过在子模块上分配 .qconfig 属性或指定 qconfig_mapping 来指定模型的哪些部分需要量化。例如,设置 model.conv1.qconfig = None 意味着 model.conv 层将不会被量化,设置 model.linear1.qconfig = custom_qconfig 意味着 model.linear1 的量化设置将使用 custom_qconfig 而不是全局 qconfig。

对于量化激活的静态量化技术,用户还需要执行以下操作:

  1. 指定激活的量化和反量化位置。这通过 QuantStubDeQuantStub 模块完成。

  2. 使用 FloatFunctional 将需要特殊处理以进行量化的张量操作包装到模块中。例如,像 addcat 这样的操作需要特殊处理才能确定输出量化参数。

  3. 融合模块:将操作/模块组合成单个模块以获得更高的精度和性能。这通过 fuse_modules() API 完成,该 API 接受要融合的模块列表。我们目前支持以下融合:[Conv, Relu]、[Conv, BatchNorm]、[Conv, BatchNorm, Relu]、[Linear, Relu]

(原型 - 维护模式) FX 图模式量化#

训练后量化(仅权重、动态和静态)有多种量化类型,其配置通过 qconfig_mappingprepare_fx 函数的一个参数)完成。

FXPTQ API 示例

import torch
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel()

#
# post training dynamic/weight_only quantization
#

# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

请遵循以下教程,了解有关 FX 图模式量化的更多信息

(原型) PyTorch 2 导出量化#

API 示例

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.export import export_for_training
from torch.ao.quantization.quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

   def forward(self, x):
       return self.linear(x)

# initialize a floating point model
float_model = M().eval()

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result should mostly stay the same
m = export_for_training(m, *example_inputs).module()
# we get a model with aten ops

# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)

# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)

# Step 3. lowering
# lower to target backend

请按照以下教程开始使用 PyTorch 2 Export Quantization

建模用户

后端开发人员(也请查看所有建模用户文档)

量化堆栈#

量化是将浮点模型转换为量化模型的过程。因此,从高层来看,量化堆栈可以分为两部分:1). 量化模型的构建块或抽象 2). 将浮点模型转换为量化模型的量化流程的构建块或抽象

量化模型#

量化张量#

为了在 PyTorch 中进行量化,我们需要能够在张量中表示量化数据。量化张量允许存储量化数据(表示为 int8/uint8/int32)以及量化参数,如比例和零点。量化张量允许许多有用的操作,使量化算术变得容易,此外还允许以量化格式对数据进行序列化。

PyTorch 支持逐张量和逐通道对称和非对称量化。逐张量意味着张量中的所有值都以相同的方式使用相同的量化参数进行量化。逐通道意味着对于每个维度,通常是张量的通道维度,张量中的值以不同的量化参数进行量化。这允许在将张量转换为量化值时减少误差,因为异常值只会影响其所在的通道,而不是整个张量。

映射通过使用以下方式转换浮点张量来执行

_images/math-quantizer-equation.png

请注意,我们确保浮点中的零在量化后没有误差地表示,从而确保填充等操作不会导致额外的量化误差。

以下是量化张量的几个关键属性

  • QScheme (torch.qscheme):一个枚举,指定我们量化张量的方式

    • torch.per_tensor_affine

    • torch.per_tensor_symmetric

    • torch.per_channel_affine

    • torch.per_channel_symmetric

  • dtype (torch.dtype):量化张量的数据类型

    • torch.quint8

    • torch.qint8

    • torch.qint32

    • torch.float16

  • 量化参数(根据 QScheme 而异):所选量化方式的参数

    • torch.per_tensor_affine 将具有以下量化参数

      • 比例(浮点数)

      • 零点(整数)

    • torch.per_channel_affine 将具有以下量化参数

      • 逐通道比例(浮点数列表)

      • 逐通道零点(整数列表)

      • 轴(整数)

量化和反量化#

模型的输入和输出是浮点张量,但量化模型中的激活是量化的,因此我们需要运算符在浮点张量和量化张量之间进行转换。

  • 量化(浮点 -> 量化)

    • torch.quantize_per_tensor(x, scale, zero_point, dtype)

    • torch.quantize_per_channel(x, scales, zero_points, axis, dtype)

    • torch.quantize_per_tensor_dynamic(x, dtype, reduce_range)

    • to(torch.float16)

  • 反量化(量化 -> 浮点)

    • quantized_tensor.dequantize() - 在 torch.float16 张量上调用 dequantize 会将张量转换回 torch.float

    • torch.dequantize(x)

量化运算符/模块#

  • 量化运算符是接受量化张量作为输入并输出量化张量的运算符。

  • 量化模块是执行量化操作的 PyTorch 模块。它们通常为线性变换和卷积等带权操作定义。

量化引擎#

当执行量化模型时,qengine (torch.backends.quantized.engine) 指定要用于执行的后端。确保 qengine 与量化模型的量化激活和权重的值范围兼容非常重要。

量化流程#

观察器和伪量化#

  • 观察器是用于以下目的的 PyTorch 模块:

    • 收集张量统计信息,例如通过观察器的张量的最小值和最大值

    • 并根据收集到的张量统计信息计算量化参数

  • 伪量化是用于以下目的的 PyTorch 模块:

    • 模拟网络中张量的量化(执行量化/反量化)

    • 它可以根据从观察器收集的统计信息计算量化参数,也可以学习量化参数

QConfig#

  • QConfig 是 Observer 或 FakeQuantize 模块类的命名元组,可以使用 qscheme、dtype 等进行配置。它用于配置运算符的观察方式。

    • 运算符/模块的量化配置

      • 不同类型的观察器/伪量化

      • 数据类型

      • 量化方案

      • quant_min/quant_max:可用于模拟低精度张量

    • 目前支持激活和权重的配置

    • 我们根据为给定运算符或模块配置的 qconfig 插入输入/权重/输出观察器

通用量化流程#

通常,流程如下:

  • prepare

    • 根据用户指定的 qconfig 插入 Observer/FakeQuantize 模块

  • 校准/训练(取决于训练后量化或量化感知训练)

    • 允许观察器收集统计数据或伪量化模块学习量化参数

  • convert

    • 将校准/训练后的模型转换为量化模型

量化有不同的模式,可以从两个方面进行分类:

就我们应用量化流程的位置而言,我们有

  1. 训练后量化(在训练后应用量化,量化参数根据样本校准数据计算)

  2. 量化感知训练(在训练期间模拟量化,以便可以使用训练数据与模型一起学习量化参数)

就我们量化操作符的方式而言,我们可以有

  • 仅权重 量化(仅权重进行静态量化)

  • 动态量化(权重静态量化,激活动态量化)

  • 静态量化(权重和激活都静态量化)

我们可以在同一个量化流程中混合使用不同的操作符量化方式。例如,我们可以进行包含静态和动态量化操作符的训练后量化。

量化支持矩阵#

量化模式支持#

量化模式

数据集要求

最适合

准确性

注意事项

训练后量化

动态/仅权重 量化

激活动态量化(fp16,int8)或未量化,权重静态量化(fp16,int8,in4)

LSTM、MLP、Embedding、Transformer

良好

易于使用,当性能受限于计算或权重导致的内存限制时,接近静态量化

静态量化

激活和权重静态量化(int8)

校准数据集

CNN

良好

提供最佳性能,可能对精度有较大影响,适用于仅支持 int8 计算的硬件

量化感知训练

动态量化

激活和权重都是伪量化的

微调数据集

MLP、Embedding

最佳

目前支持有限

静态量化

激活和权重都是伪量化的

微调数据集

CNN、MLP、Embedding

最佳

通常用于静态量化导致精度不佳的情况,并用于缩小精度差距

请参阅我们的《PyTorch 量化简介》博客文章,以更全面地了解这些量化类型之间的权衡。

量化流程支持#

PyTorch 提供两种量化模式:Eager Mode Quantization 和 FX Graph Mode Quantization。

即时模式量化是一个 Beta 功能。用户需要手动进行融合并指定量化和反量化的发生位置,而且它只支持模块而不支持函数。

FX Graph Mode Quantization 是 PyTorch 中的一个自动化量化框架,目前仍处于原型阶段。它通过增加对函数(functional)的支持和自动化量化过程来改进 Eager Mode Quantization,尽管人们可能需要重构模型以使其与 FX Graph Mode Quantization 兼容(可使用 torch.fx 进行符号跟踪)。请注意,FX Graph Mode Quantization 预计无法在任意模型上工作,因为模型可能无法进行符号跟踪。我们将把它集成到 torchvision 等领域库中,用户将能够使用 FX Graph Mode Quantization 量化类似于受支持领域库中的模型。对于任意模型,我们将提供一般指导方针,但要使其真正工作,用户可能需要熟悉 torch.fx,特别是如何使模型可进行符号跟踪。

建议量化的新用户首先尝试 FX Graph Mode Quantization,如果不起作用,用户可以尝试遵循使用 FX Graph Mode Quantization 的指南或退回到 Eager 模式量化。

下表比较了 Eager Mode Quantization 和 FX Graph Mode Quantization 之间的差异

Eager Mode Quantization

FX Graph Mode Quantization

发布状态

测试版

原型

运算符融合

手动

自动

量化/反量化放置

手动

自动

量化模块

支持

支持

量化函数/Torch 运算符

手动

自动

支持自定义

有限支持

完全支持

量化模式支持

训练后量化:静态、动态、仅权重

量化感知训练:静态

训练后量化:静态、动态、仅权重

量化感知训练:静态

输入/输出模型类型

torch.nn.Module

torch.nn.Module (可能需要一些重构才能使模型与 FX Graph Mode Quantization 兼容)

后端/硬件支持#

硬件

内核库

Eager Mode Quantization

FX Graph Mode Quantization

量化模式支持

服务器 CPU

fbgemm/onednn

支持

所有支持

移动 CPU

qnnpack/xnnpack

服务器 GPU

TensorRT(早期原型)

不支持此功能,因为它需要图形

支持

静态量化

目前,PyTorch 支持以下后端高效运行量化运算符:

  • 支持 AVX2 或更高版本的 x86 CPU(没有 AVX2 时某些操作的实现效率低下),通过由 fbgemmonednn 优化的 x86(详见 RFC

  • ARM CPU(通常在移动/嵌入式设备中找到),通过 qnnpack

  • (早期原型)通过 fx2trt 支持 NVidia GPU,通过 TensorRT(即将开源)

本机 CPU 后端注意事项#

我们通过相同的原生 PyTorch 量化运算符暴露 x86qnnpack,因此需要额外的标志来区分它们。x86qnnpack 的相应实现是根据 PyTorch 构建模式自动选择的,但用户可以选择通过将 torch.backends.quantization.engine 设置为 x86qnnpack 来覆盖此设置。

准备量化模型时,必须确保量化计算所用的 qconfig 和引擎与模型将执行的后端相匹配。qconfig 控制量化过程中使用的观察器类型。qengine 控制在为线性函数和卷积函数以及模块打包权重时是否使用 x86qnnpack 特定打包函数。例如

x86 的默认设置

# set the qconfig for PTQ
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs
qconfig = torch.ao.quantization.get_default_qconfig('x86')
# or, set the qconfig for QAT
qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
# set the qengine to control weight packing
torch.backends.quantized.engine = 'x86'

qnnpack 的默认设置

# set the qconfig for PTQ
qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
# or, set the qconfig for QAT
qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
# set the qengine to control weight packing
torch.backends.quantized.engine = 'qnnpack'

操作符支持#

运算符覆盖范围在动态和静态量化之间有所不同,具体如下表所示。请注意,对于 FX 图模式量化,也支持相应的函数。

静态量化

动态量化

nn.Linear
nn.Conv1d/2d/3d
Y
Y
Y
N
nn.LSTM
nn.GRU
N
N
Y
Y
nn.RNNCell
nn.GRUCell
nn.LSTMCell
N
N
N
Y
Y
Y

nn.EmbeddingBag

Y(激活值为 fp32)

Y

nn.Embedding

Y

Y

nn.MultiheadAttention

不支持

不支持

激活

广泛支持

不变,计算保持在 fp32

注意:这将很快根据从 native_backend_config_dict 生成的一些信息进行更新。

量化 API 参考#

量化 API 参考包含量化 API 的文档,例如量化通道、量化张量操作以及支持的量化模块和函数。

量化后端配置#

量化后端配置包含有关如何为各种后端配置量化工作流的文档。

量化精度调试#

量化精度调试包含有关如何调试量化精度的文档。

量化定制#

虽然提供了用于根据观察到的张量数据选择比例因子和偏差的观察器默认实现,但开发人员可以提供自己的量化函数。量化可以有选择地应用于模型的不同部分,或者为模型的不同部分进行不同的配置。

我们还支持对 conv1d()conv2d()conv3d()linear() 进行逐通道量化。

量化工作流通过在模型的模块层次结构中添加(例如,添加观察器作为 .observer 子模块)或替换(例如,将 nn.Conv2d 转换为 nn.quantized.Conv2d)子模块来工作。这意味着模型在整个过程中保持为常规的 nn.Module 基于实例,因此可以与 PyTorch API 的其余部分一起工作。

量化自定义模块 API#

Eager 模式和 FX 图模式量化 API 都为用户提供了一个钩子,用于以自定义方式指定量化模块,并带有用户定义的观察和量化逻辑。用户需要指定

  1. 源 fp32 模块的 Python 类型(存在于模型中)

  2. 观察模块的 Python 类型(由用户提供)。此模块需要定义一个 from_float 函数,该函数定义如何从原始 fp32 模块创建观察模块。

  3. 量化模块的 Python 类型(由用户提供)。此模块需要定义一个 from_observed 函数,该函数定义如何从观察模块创建量化模块。

  4. 描述上述 (1)、(2)、(3) 的配置,传递给量化 API。

框架将执行以下操作

  1. prepare 模块交换期间,它将把 (1) 中指定的每个类型模块转换为 (2) 中指定的类型,使用 (2) 中类的 from_float 函数。

  2. convert 模块交换期间,它将把 (2) 中指定的每个类型模块转换为 (3) 中指定的类型,使用 (3) 中类的 from_observed 函数。

目前,要求 ObservedCustomModule 将具有单个张量输出,并且框架(而不是用户)将在此输出上添加观察器。观察器将作为自定义模块实例的属性存储在 activation_post_process 键下。放宽这些限制可能会在将来完成。

自定义 API 示例

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import QConfigMapping
import torch.ao.quantization.quantize_fx

# original fp32 module to replace
class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# custom observed module, provided by user
class ObservedCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_float(cls, float_module):
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.linear)
        observed.qconfig = float_module.qconfig
        return observed

# custom quantized module, provided by user
class StaticQuantCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_observed(cls, observed_module):
        assert hasattr(observed_module, 'qconfig')
        assert hasattr(observed_module, 'activation_post_process')
        observed_module.linear.activation_post_process = \
            observed_module.activation_post_process
        quantized = cls(nnq.Linear.from_float(observed_module.linear))
        return quantized

#
# example API call (Eager mode quantization)
#

m = torch.nn.Sequential(CustomModule()).eval()
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: StaticQuantCustomModule
    }
}
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare(
    m, prepare_custom_config_dict=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.convert(
    mp, convert_custom_config_dict=convert_custom_config_dict)
#
# example API call (FX graph mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {
            CustomModule: ObservedCustomModule,
        }
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {
            ObservedCustomModule: StaticQuantCustomModule,
        }
    }
}
mp = torch.ao.quantization.quantize_fx.prepare_fx(
    m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.quantize_fx.convert_fx(
    mp, convert_custom_config=convert_custom_config_dict)

最佳实践#

1. 如果您正在使用 x86 后端,我们需要使用 7 位而不是 8 位。请确保您减小 quant\_minquant\_max 的范围,例如,如果 dtypetorch.quint8,请确保将自定义 quant_min 设置为 0,将 quant_max 设置为 127 (255 / 2);如果 dtypetorch.qint8,请确保将自定义 quant_min 设置为 -64 (-128 / 2),将 quant_max 设置为 63 (127 / 2),如果您调用 torch.ao.quantization.get_default_qconfig(backend)torch.ao.quantization.get_default_qat_qconfig(backend) 函数来获取 x86qnnpack 后端的默认 qconfig,我们已经正确设置了这一点

2. 如果选择 onednn 后端,默认 qconfig 映射 torch.ao.quantization.get_default_qconfig_mapping('onednn') 和默认 qconfig torch.ao.quantization.get_default_qconfig('onednn') 中将使用 8 位激活。建议在支持矢量神经网络指令 (VNNI) 的 CPU 上使用。否则,将激活观察器的 reduce_range 设置为 True,以便在不支持 VNNI 的 CPU 上获得更好的精度。

常见问题#

  1. 如何在 GPU 上进行量化推理?

    我们尚未提供官方 GPU 支持,但这是一个积极开发的领域,您可以在此处找到更多信息

  2. 在哪里可以为我的量化模型获得 ONNX 支持?

    如果您在导出模型(使用 torch.onnx 下的 API)时遇到错误,您可以在 PyTorch 仓库中提出问题。在问题标题前加上 [ONNX] 并将问题标记为 module: onnx

    如果您遇到 ONNX Runtime 的问题,请在 GitHub - microsoft/onnxruntime 提交问题。

  3. 如何将量化与 LSTM 结合使用?

    LSTM 通过我们的自定义模块 API 在 eager 模式和 FX 图模式量化中都受支持。示例可在 Eager 模式中找到:pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm FX 图模式:pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm

常见错误#

将非量化张量传入量化内核#

如果您看到类似以下错误

RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend...

这意味着您正在尝试将非量化张量传递给量化内核。一个常见的解决方法是使用 torch.ao.quantization.QuantStub 来量化张量。这需要在 Eager 模式量化中手动完成。一个端到端示例

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        # during the convert step, this will be replaced with a
        # `quantize_per_tensor` call
        x = self.quant(x)
        x = self.conv(x)
        return x

将量化张量传递给非量化内核#

如果您看到类似以下错误

RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend.

这意味着您正在尝试将量化张量传递给非量化内核。一个常见的解决方法是使用 torch.ao.quantization.DeQuantStub 来反量化张量。这需要在 Eager 模式量化中手动完成。一个端到端示例

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        # this module will not be quantized (see `qconfig = None` logic below)
        self.conv2 = torch.nn.Conv2d(1, 1, 1)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # during the convert step, this will be replaced with a
        # `quantize_per_tensor` call
        x = self.quant(x)
        x = self.conv1(x)
        # during the convert step, this will be replaced with a
        # `dequantize` call
        x = self.dequant(x)
        x = self.conv2(x)
        return x

m = M()
m.qconfig = some_qconfig
# turn off quantization for conv2
m.conv2.qconfig = None

保存和加载量化模型#

在量化模型上调用 torch.load 时,如果看到类似以下错误:

AttributeError: 'LinearPackedParams' object has no attribute '_modules'

这是因为不支持直接使用 torch.savetorch.load 保存和加载量化模型。要保存/加载量化模型,可以使用以下方法:

  1. 保存/加载量化模型的 state_dict

一个例子

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        return x

m = M().eval()
prepare_orig = prepare_fx(m, {'' : default_qconfig})
prepare_orig(torch.rand(5, 5))
quantized_orig = convert_fx(prepare_orig)

# Save/load using state_dict
b = io.BytesIO()
torch.save(quantized_orig.state_dict(), b)

m2 = M().eval()
prepared = prepare_fx(m2, {'' : default_qconfig})
quantized = convert_fx(prepared)
b.seek(0)
quantized.load_state_dict(torch.load(b))
  1. 使用 torch.jit.savetorch.jit.load 保存/加载脚本化的量化模型

一个例子

# Note: using the same model M from previous example
m = M().eval()
prepare_orig = prepare_fx(m, {'' : default_qconfig})
prepare_orig(torch.rand(5, 5))
quantized_orig = convert_fx(prepare_orig)

# save/load using scripted model
scripted = torch.jit.script(quantized_orig)
b = io.BytesIO()
torch.jit.save(scripted, b)
b.seek(0)
scripted_quantized = torch.jit.load(b)

使用 FX Graph 模式量化时出现符号跟踪错误#

符号可追溯性是(原型 - 维护模式) FX 图模式量化的要求,因此如果您将无法符号可追溯的 PyTorch 模型传递给 torch.ao.quantization.prepare_fxtorch.ao.quantization.prepare_qat_fx,我们可能会看到以下错误:

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

请查看符号跟踪的限制并使用使用 FX 图模式量化的用户指南来解决问题。