静态量化¶
静态量化是指在推理或生成过程中,对所有输入使用固定的量化范围。与动态量化不同,动态量化会为每个新的输入批次动态计算新的量化范围,而静态量化通常会带来更高效的计算,但可能会以牺牲量化精度为代价,因为我们无法实时适应输入分布的变化。
在静态量化中,这个固定的量化范围通常在量化模型之前,通过类似输入进行校准。在校准阶段,我们首先将观察器插入模型中,以“观察”要量化的输入的分布,并使用该分布来决定最终量化模型时使用的尺度和零点。
在本教程中,我们将通过一个示例来演示如何在 torchao 中实现这一点。所有代码都可以在此示例脚本中找到。让我们从我们的玩具线性模型开始
import copy
import torch
class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, k, bias=False)
self.linear2 = torch.nn.Linear(k, n, bias=False)
def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (
torch.randn(
batch_size, self.linear1.in_features, dtype=dtype, device=device
),
)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")
m = torch.compile(m, mode="max-autotune")
校准阶段¶
torchao 提供了一个简单的观察器实现,AffineQuantizedMinMaxObserver,它记录了在校准阶段流经观察器的最小值和最大值。欢迎用户实现自己所需的、更高级的观察技术,例如那些依赖移动平均或直方图的技术,这些技术未来可能会添加到 torchao 中。
from torchao.quantization.granularity import PerAxis, PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType
# per tensor input activation asymmetric quantization
act_obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float32,
zero_point_dtype=torch.float32,
)
# per channel weight asymmetric quantization
weight_obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float32,
zero_point_dtype=torch.float32,
)
接下来,我们定义我们的观测线性层,它将替换我们的torch.nn.Linear。这是一个高精度(例如 fp32)线性模块,其中插入了上述观察器,用于在校准期间记录输入激活和权重值。
import torch.nn.functional as F
class ObservedLinear(torch.nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
act_obs: torch.nn.Module,
weight_obs: torch.nn.Module,
bias: bool = True,
device=None,
dtype=None,
):
super().__init__(in_features, out_features, bias, device, dtype)
self.act_obs = act_obs
self.weight_obs = weight_obs
def forward(self, input: torch.Tensor):
observed_input = self.act_obs(input)
observed_weight = self.weight_obs(self.weight)
return F.linear(observed_input, observed_weight, self.bias)
@classmethod
def from_float(cls, float_linear, act_obs, weight_obs):
observed_linear = cls(
float_linear.in_features,
float_linear.out_features,
act_obs,
weight_obs,
False,
device=float_linear.weight.device,
dtype=float_linear.weight.dtype,
)
observed_linear.weight = float_linear.weight
observed_linear.bias = float_linear.bias
return observed_linear
要将这些观察器实际插入到我们的玩具模型中
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
def insert_observers_(model, act_obs, weight_obs):
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
def replacement_fn(m):
copied_act_obs = copy.deepcopy(act_obs)
copied_weight_obs = copy.deepcopy(weight_obs)
return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)
insert_observers_(m, act_obs, weight_obs)
现在我们准备校准模型,这将用校准期间记录的统计数据填充我们插入的观察器。我们可以简单地通过向我们的“观测”模型提供一些示例输入来完成此操作
for _ in range(10):
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
m(*example_inputs)
量化阶段¶
实际量化模型有多种方法。这里我们介绍一种更简单的替代方案,即定义一个 QuantizedLinear 类,我们将用它替换我们的 ObservedLinear。定义这个新类并不是严格必要的。对于只使用现有 torch.nn.Linear 的替代方法,请参阅完整的示例脚本。
from torchao.dtypes import to_affine_quantized_intx_static
class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
act_obs: torch.nn.Module,
weight_obs: torch.nn.Module,
weight: torch.Tensor,
bias: torch.Tensor,
target_dtype: torch.dtype,
):
super().__init__()
self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
weight_scale, weight_zero_point = weight_obs.calculate_qparams()
assert weight.dim() == 2
block_size = (1, weight.shape[1])
self.target_dtype = target_dtype
self.bias = bias
self.qweight = to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, self.target_dtype
)
def forward(self, input: torch.Tensor):
block_size = input.shape
qinput = to_affine_quantized_intx_static(
input,
self.act_scale,
self.act_zero_point,
block_size,
self.target_dtype,
)
return F.linear(qinput, self.qweight, self.bias)
@classmethod
def from_observed(cls, observed_linear, target_dtype):
quantized_linear = cls(
observed_linear.in_features,
observed_linear.out_features,
observed_linear.act_obs,
observed_linear.weight_obs,
observed_linear.weight,
observed_linear.bias,
target_dtype,
)
return quantized_linear
这个线性类在开始时计算输入激活和权重的尺度和零点,从而有效地固定了未来前向调用的量化范围。现在,要使用这个线性类实际量化模型,我们可以定义以下配置并将其传递给 torchao 的主 quantize_ API
from dataclasses import dataclass
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
@dataclass
class StaticQuantConfig(AOBaseConfig):
target_dtype: torch.dtype
@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(
module: torch.nn.Module,
config: StaticQuantConfig,
):
"""
Define a transformation associated with `StaticQuantConfig`.
This is called by `quantize_`, not by the user directly.
"""
return QuantizedLinear.from_observed(module, config.target_dtype)
# filter function to identify which modules to swap
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
# perform static quantization
quantize_(m, StaticQuantConfig(torch.uint8), is_observed_linear)
现在,我们将看到模型中的线性层已替换为我们的 QuantizedLinear 类,并具有固定的输入激活尺度和固定的量化权重
>>> m
OptimizedModule(
(_orig_mod): ToyLinearModel(
(linear1): QuantizedLinear()
(linear2): QuantizedLinear()
)
)
>>> m.linear1.act_scale
tensor([0.0237], device='cuda:0')
>>> m.linear1.qweight
AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[142, 31, 42, ..., 113, 157, 57],
[ 59, 160, 70, ..., 23, 150, 67],
[ 44, 49, 241, ..., 238, 69, 235],
...,
[228, 255, 201, ..., 114, 236, 73],
[ 50, 88, 83, ..., 109, 209, 92],
[184, 141, 35, ..., 224, 110, 66]], device='cuda:0',
dtype=torch.uint8)... , scale=tensor([0.0009, 0.0010, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010,
0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010,
0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0009,
0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010,
0.0009, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009,
0.0010, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
0.0010], device='cuda:0')... , zero_point=tensor([130., 128., 122., 130., 132., 128., 125., 130., 126., 128., 129., 126.,
128., 128., 128., 128., 129., 127., 130., 125., 128., 133., 126., 126.,
128., 124., 127., 128., 128., 128., 129., 124., 126., 133., 129., 127.,
126., 124., 130., 126., 127., 129., 124., 125., 127., 130., 128., 132.,
128., 129., 128., 129., 131., 132., 127., 135., 126., 130., 124., 136.,
131., 124., 130., 129.], device='cuda:0')... , _layout=PlainLayout()), block_size=(1, 64), shape=torch.Size([64, 64]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False)
在本教程中,我们通过一个基本示例介绍了如何在 torchao 中执行整数静态量化。我们还提供了一个如何在 float8 中执行相同静态量化的示例。更多详细信息请参阅完整的示例脚本!