快捷方式

量化概述

首先,我们想展示 torchao 的堆栈

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
    Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
---------------------------------------------------------------------------------------------
  Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
            Basic dtypes: uint1-uint7, int1-int8, float3-float8

任何量化算法都将使用上述堆栈中的某些组件,例如,每行动态 float8 激活和 float8 权重量化(默认首选项)使用

基本数据类型

dtype 是一个被过度使用的术语,我们所说的基本数据类型是指不需要任何额外元数据就有意义的数据类型(例如,当人们调用 torch.empty(.., dtype) 时有意义)。更多细节请参阅 此帖

无论我们进行何种量化,最终都将使用一些低精度数据类型来表示量化数据或量化参数。与 torchao 相关的低精度数据类型是:

  • PyTorch 2.3 及更高版本中可用的 torch.uint1torch.uint7

  • PyTorch 2.6 及更高版本中可用的 torch.int1torch.int7

  • torch.float4_e2m1fn_x2torch.float8_e4m3fntorch.float8_e4m3fnuztorch.float8_e5m2torch.float8_e5m2fnuztorch.float8_e8m0fnu

在实际实现方面,uint1uint7int1int7 只是占位符,没有实际实现(即,对于这些数据类型的 PyTorch Tensor,算子不起作用)。添加了这些数据类型的示例 PR 可以在 这里 找到。浮点数据类型是我们称之为“Shell Dtypes”的数据类型,它们具有有限的算子支持。

更多详情请参阅 官方 PyTorch 数据类型文档

注意

诸如 mxfp8、mxfp4、nvfp4 之类的派生数据类型使用这些基本数据类型实现,例如,mxfp4 使用 torch.float8_e8m0fnu 作为 scale,并使用 torch.float4_e2m1fn_x2 作为 4 位数据。

量化原始算子

量化原始算子是指用于在低精度量化张量和高精度张量之间进行转换的算子。我们主要有以下量化原始算子:

  • choose_qparams 算子:根据原始张量选择量化参数,通常用于动态量化,例如,仿射量化的 scale 和 zero_point

  • quantize 算子:根据量化参数,将原始高精度张量量化为前一节中提到的数据类型的低精度张量

  • dequantize 算子:根据量化参数,将低精度张量反量化为高精度张量

为了适应特定用例,上述算子可能会有所变化,例如,对于静态量化,我们可能有一个 choose_qparams_affine_with_min_max,它会根据观察过程中得出的 min/max 值来选择量化参数。

对于不同的内核库,我们可以在 torchao 中使用算子的多个版本,例如,将 bfloat16 张量量化为原始 float8 张量并获取 scale:_choose_scale_float8_quantize_affine_float8 用于 torchao 实现,以及来自 fbgemm 库的 torch.ops.triton.quantize_fp8_row

高效内核

我们还将提供与低精度张量一起工作的高效内核,例如:

  • torch.ops.fbgemm.f8f8bf16_rowwise (fbgemm 库中的行式 float8 激活和 float8 权重矩阵乘法内核)

  • torch._scaled_mm (PyTorch 中用于行式和张量式计算的 float8 激活和 float8 权重矩阵乘法内核)

  • int_matmul:接受两个 int8 张量并输出一个 int32 张量

  • int_scaled_matmul:执行矩阵乘法并对结果应用 scale。

注意

我们还可以依赖 torch.compile 生成内核(通过 triton),例如,当前的 int8 仅权重量化 内核 仅依靠 torch.compile 来加速。在这种情况下,没有与量化类型相对应的自定义手写“高效内核”。

量化张量(派生数据类型和打包格式)

在基本数据类型、量化原始算子和高效内核的基础上,我们可以将它们组合起来构建一个量化(低精度)张量,通过继承 torch.Tensor 来实现。这个张量可以由一个高精度张量和一些参数来构造,这些参数可以配置用户想要的特定量化。我们也可以称之为派生数据类型,因为它可以由基本数据类型的张量和一些额外的元数据(如 scale)来表示。

量化张量的另一个维度是打包格式,即量化的原始数据在内存中的布局方式。例如,对于 int4,我们可以将两个元素打包到一个 uint8 值中,或者人们可以进行一些预混/交换操作,以使格式对于内存操作(从内存加载到寄存器)和计算更有效。

所以,总的来说,我们通过派生数据类型和打包格式来构造张量子类。

TorchAO 中的张量子类

张量

派生数据类型

打包格式

支持

Float8Tensor

缩放的 float8

普通(无需打包)

float8 激活 + float8 权重动态量化和 float8 仅权重量化

Int4Tensor

缩放的 int4

普通(将 2 个相邻的 int4 打包到一个 int8 值中)

int4 仅权重量化

Int4PreshuffledTensor

缩放的 int4

预混(用于优化加载的特殊格式)

float8 激活 + int4 权重动态量化和 int4 仅权重量化

注意

我们没有粒度特定的张量子类,即没有 Float8RowwiseTensor 或 Float8BlockwiseTensor,所有粒度都在同一个张量中实现。我们通常使用一个通用的 block_size 属性来区分不同的粒度,并且每个张量只允许支持所有可能粒度选项的一个子集。

注意

我们也不在名称中使用动态激活,因为我们讨论的是权重张量对象,在张量子类名称中包含激活信息会造成混淆。但是,我们在同一个线性函数实现中同时实现了仅权重和动态激活量化,而无需依赖额外的抽象。这使得相关的量化操作(激活和权重的量化)保持在同一个张量子类中。

在如何量化张量方面,大多数张量使用仿射量化,这意味着低精度张量通过仿射映射从高精度张量量化,即:low_precision_val = high_precision_val / scale + zero_point,其中 scalezero_point 是可以通过量化原始算子或通过某些优化过程计算出的量化参数。另一种常见的量化类型,尤其对于较低的比特宽度(例如低于 4 位)是基于码本/查找表的量化,其中原始量化数据是我们可以用来查找存储每个索引对应值的 codebook 的索引。一种获取码本和用于码本量化的原始量化数据的方法是 K-means 聚类。

量化算法/流程

在堆栈的顶部是最终的量化算法和量化流程。传统上,我们有仅权重量化、动态量化和静态量化,但现在我们也看到了更多类型的量化出现。

出于演示目的,假设在前面的步骤之后,我们定义了 Float8TensorFloat8Tensor.from_hp 接受一个高精度浮点张量和一个 target_dtype(例如 torch.float8_e4m3fn)并将其转换为 Float8Tensor

注意:以下内容均用于解释概念,有关我们提供的工具和示例的更详细介绍,请参阅 贡献者指南

仅权重量化

这是最简单的量化形式,并且易于将仅权重量化应用于模型,特别是由于我们拥有量化张量。我们所需要做的就是:

linear_module.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear_module.weight, ...), requires_grad=False))

将以上方法应用于模型中的所有线性模块,我们将获得一个仅权重量化模型。

动态激活和权重量化

以前称为“动态量化”,但它意味着我们在运行时动态地量化激活,并且也量化权重。与仅权重量化相比,主要问题是如何将量化应用于激活。在 torchao 中,我们传递激活的量化关键字参数,当需要时(例如在线性层中),这些关键字参数将被应用于激活。

activation_dtype = torch.float8_e4m3fn
activation_granularity = PerRow()
# define kwargs for float8 activation quantization
act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
  activation_dtype,
  activation_granularity,
)
weight_dtype = torch.float8_e4m3fn
weight_granularity = PerRow()
quantized_weight = Float8Tensor.from_hp(linear_module.weight, float8_dtype=weight_dtype, granularity=weight_granularity, act_quant_kwargs=act_quant_kwargs)
linear_module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False))

静态激活量化和权重量化

我们暂时跳过说明,因为我们还没有看到许多使用基于张量子类的静态量化流程的用例。我们建议查看 PT2 导出量化流程 以进行静态量化。

其他量化流程

对于不属于以上任何一种的量化流程/算法,我们也打算提供常见模式的示例。例如,GPTQ 类量化流程,它被 Autoround 采用,它使用 MultiTensor 和模块钩子来优化模块。

如果您正在开发一种新的量化算法/流程,并且不确定如何以 PyTorch 原生方式实现它,请随时提交一个 issue 来描述您的算法是如何工作的,我们可以帮助您提供实现细节方面的建议。

训练

上述流程主要侧重于推理,但低比特数据类型张量也可用于训练。

float8 训练的用户文档可以在 这里 找到,微调文档可以在 这里 找到。

量化感知训练

TorchAO 也通过 quantize_ API 支持 量化感知训练

低比特优化器

我们支持 低比特优化器,它们实现了特定类型的 4 位、8 位和 float8 量化,并且可以与 FSDP 组合(使用查找表量化)。

量化训练

我们在 main/torchao/prototype/quantized_training 中有量化训练原型,并且也可以扩展现有的张量子类以支持训练。初步启用正在进行中,但仍需要大量后续工作,包括使其适用于不同的内核等。

您还可以查看关于 量化训练 的教程,该教程介绍了如何使 dtype 张量子类可训练。

案例研究:torchao 中的 float8 动态激活和 float8 权重量化是如何工作的?

为了将所有内容连接起来,以下是 torchao 中 float8 动态激活和 float8 权重量化的更详细的演练(默认内核首选项,在 H100 上,如果安装了 fbgemm_gpu_genai 库)

量化流程:quantize_(model, Float8DynamicActivationFloat8WeightConfig())
  • 发生的情况:linear.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear.weight), requires_grad=False)

  • 量化原始算子:torch.ops.triton.quantize_fp8_row

  • 量化张量将是 Float8Tensor,一个具有缩放 float8 派生数据类型的量化张量。

模型执行期间:model(input)
  • torch.ops.fbgemm.f8f8bf16_rowwise 在输入、原始 float8 权重和 scale 上被调用。

量化期间

首先,我们从 API 调用开始:quantize_(model, Float8DynamicActivationFloat8WeightConfig())。它的作用是将模型中 nn.Linear 模块的权重转换为 Float8Tensor,采用普通打包格式,无需打包,因为我们有 torch.float8_e4m3fn,它可以直接表示量化的 float8 原始数据而无需额外操作。

  • quantize_:量化权重的模型级 API,通过应用用户(第二个参数)的配置来实现。

  • Float8DynamicActivationFloat8WeightConfig:float8 动态激活和 float8 权重量化的配置 * 调用量化原始算子 torch.ops.triton.quantize_fp8_row 将 bfloat16 张量量化为 float8 原始张量并获取 scale。

模型执行期间

当我们运行量化模型 model(inputs) 时,我们将通过 nn.Linear 的函数式线性算子。

return F.linear(input, weight, bias)

其中输入是 bfloat16 张量,权重是 Float8Tensor。它会调用 Float8Tensor 子类的 __torch_function__,当输入之一是 Float8Tensor 时,最终会进入 F.linear 的实现。

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
      args[0],
      args[1],
      args[2] if len(args) > 2 else None,
    )
    # quantizing activation, if `act_quant_kwargs` is specified
    if act_quant_kwargs is not None:
      input_tensor = _choose_quant_func_and_quantize_tensor(
          input_tensor, act_quant_kwargs
      )

    # omitting kernel_preference related code
    # granularity checks, let's say we are doing rowwise quant
    # both input_tensor and weight_tensor will now be Float8Tensor
    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
    wq = weight_tensor.qdata.contiguous()
    x_scale = input_tensor.scale
    w_scale = weight_tensor.scale
    res = torch.ops.fbgemm.f8f8bf16_rowwise(
       xq,
       wq,
       x_scale,
       w_scale,
    ).reshape(out_shape)
    return res

该函数首先将输入量化为 Float8Tensor,然后从输入张量和权重张量中获取原始 float 张量和 scale:t.qdatat.scale,并调用 fbgemm 内核进行 float8 动态量化的矩阵乘法:torch.ops.fbgemm.f8f8bf16_rowwise

保存/加载期间

由于 Float8Tensor 权重仍然是 torch.Tensor,因此保存/加载与原始高精度浮点模型的工作方式相同。更多详情请参阅 序列化文档

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源