• 文档 >
  • 编写自己的量化张量
快捷方式

编写自己的量化张量

torchao 中的量化建立在张量子类的基础上。它们是 torchao 的主要扩展点,用于使用低精度计算提供灵活的推理和训练支持,同时与 torch.compile、autograd 和分布式原语等重要 PyTorch 功能进行组合。

在本教程中,我们将重点介绍与模块交换相比,利用张量子类的好处,并逐步介绍如何使用此方法表达量化的简单示例。

什么是张量子类?

张量子类是简单地继承自 torch.Tensor 的类。它们允许用户在模型中现有操作之间插入自定义计算逻辑,从而使顶级 torch 命名空间中的函数(如 torch.add)能够继续无缝工作。

张量子类方法的明显替代方案是模块交换:例如,只需将模型中的所有 nn.Linear 模块替换为自定义的 Int8QuantizedLinear 模块。与此方法相比,使用张量子类有几个重要好处

  1. 更细粒度的集成点。 模块交换在模块级别拦截计算,因此不适用于依赖 torch 函数或原生模块变体(例如,nn.Linear 的略微修改版本)的模型。相比之下,由于张量子类在函数/操作级别拦截计算,因此只要使用相同的函数/操作,我们就可以对模型进行量化。

  2. 更好的可组合性。 使用模块交换组合多个功能会很笨拙。例如,组合两个现有的 Int8QuantizedLinear 和 DistributedLinear 模块将要求用户创建另一个复制这些功能的线性类。张量子类通过简单地将一个子类包装在另一个子类中来解决此问题。如果外部张量(例如 DTensor)知道内部张量已量化,这也可以提供性能优势,因此可以使用更少的网络和内存带宽执行昂贵的 allgather 操作。

  3. 重用 PyTorch 组件。 使用张量子类表达量化是很自然的,因为量化张量只是具有不同 dtype 的 torch.Tensor。模型结构不变(nn.Linears 仍然是 nn.Linears),因此后续优化过程也可以保持与以前完全相同。


在本教程的其余部分,我们将通过一个示例来介绍如何使用这两种方法表达量化。有关张量子类的更多阅读,请参阅

使用模块交换进行量化

我们首先举一个简单的例子,说明如何使用模块交换实现 int8 对称仅权重量化。所有代码都可以在此示例脚本中找到。我们将使用以下函数将 float32 张量量化为 int8 张量

from typing import Tuple
import torch

def int8_symmetric_quantize(
    fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Symmetrically quantize the torch.float32 tensor into torch.int8.
    Return a 2-tuple of (quantized value, scale).

    input: dimensions=[M, N], dtype=torch.float32
    output: dimensions=[M, N], dtype=torch.int8
    scale: dimensions=[M, 1], dtype=torch.float32
    """
    quant_min = -128
    quant_max = 127
    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    scale = scale.view(fp32_tensor.shape[0], -1)
    out = torch.round(fp32_tensor * (1.0 / scale))
    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
    return out, scale

接下来,我们将创建一个新的 QuantizedLinear 模块,该模块调用此函数以动态量化权重

class QuantizedLinear(torch.nn.Linear):
    """
    Linear module that performs dynamic and symmetric weight-only
    int8 quantization.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_int8, scale = int8_symmetric_quantize(self.weight)
        return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_linear = cls(mod.in_features, mod.out_features, mod.bias)
        new_linear.weight = mod.weight
        return new_linear

然后,剩下的唯一事情就是将模型中的所有 nn.Linear 模块替换为新的 QuantizedLinear 模块。让我们使用以下玩具模型进行演示

import copy

class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model, name, new_linear)

验证模型现在是否使用我们的 QuantizedLinear 模块。此模型现在可以使用了!

>>> print(float_model)
ToyModel(
  (linear1): Linear(in_features=64, out_features=128, bias=False)
  (linear2): Linear(in_features=128, out_features=32, bias=False)
)

>>> print(quantized_model)
ToyModel(
  (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
  (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)

这种简单方法的一个重要缺点是灵活性。目前这仅适用于原生 PyTorch 模块,但如果模型有略微修改的线性模块(例如,支持分布式训练)怎么办?它也不适用于直接调用线性函数版本 (torch.nn.functional.linear) 的模型。

此外,假设我们想将此功能与通过模块交换实现的分布式功能组合。除了创建另一个结合了这两种功能的模块之外,没有干净的方法可以做到这一点。这些限制可以通过张量子类解决,这是在模型中插入自定义计算(如量化)的更优雅方式。

使用张量子类进行量化

在这里,我们将使用基于 __torch_dispatch__ 的张量子类重新实现上述量化技术。

张量子类(通常使用 __torch_dispatch__)是 PyTorch 中一个非常强大/灵活的扩展点。它们作为扩展点主要有两个目的

  1. 张量子类允许您覆盖(几乎)每个 PyTorch API 的实现,并且在实现其他 PyTorch 产品时大量使用

  2. 张量子类允许您将张量数据与附加元数据耦合。一些示例

    1. [分布式] 张量在秩之间如何分片的元数据(DTensor文档

    2. [量化] 比例/零点元数据(AffineQuantizedTensor

    3. [不规则性] 不规则结构元数据(NestedTensor文档

其他一些对张量子类感兴趣的资源

  1. __torch_dispatch__ 文档(链接

  2. __torch_dispatch__ 是什么(以及为什么)(链接

  3. 使用 __torch_dispatch__ 实现 FlopCounter 和 MemoryTracker 的 Google Colab(链接

话不多说,让我们从为对称量化定义骨干张量子类开始

class Int8SymmetricTensor(torch.Tensor):
    """
    Our subclass represents a tensor that has been quantized to int8
    It will hold two inner tensors:
      int_data: int8[M, N]
      scale: fp32[M, 1]
    """

    @staticmethod
    @torch._dynamo.disable
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            int_data.shape,
            strides=int_data.stride(),
            storage_offset=int_data.storage_offset(),
            dtype=scale.dtype,
            device=int_data.device,
        )

    @torch._dynamo.disable
    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        # inner data expected to be quantized already
        assert int_data.dtype is torch.int8
        # we could do more work to support ndim > 2!
        assert int_data.ndim == 2
        assert scale.ndim == 2
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        """
        Returns a tuple of:
          names of all inner tensor attributes (two in our case)
          any other additional, non-tensor metadata.

        Needed for PT2 support.
        """
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
        """
         __tensor_unflatten__ should effectively undo __tensor_flatten__.

        inputs:
          a dict mapping names of inner tensor attributes back to the tensors
          the constant metadata from __tensor_flatten__
        output:
          a new instance of your subclass

        Needed for PT2 support.
        """
        assert extra_metadata is None
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        return Int8SymmetricTensor(int_data, scale)

    def __repr__(self):
        return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'

    @staticmethod
    def from_float(float_tensor):
        """
        Actually performs the symmetric quantization.
        In our simple inference example we will quantize weights "ahead-of-time",
        although later in a training example we can quantize/dequantize
        during model execution, inside of our __torch_dispatch__

        input:
          float32 torch.Tensor
        output:
          Int8SymmetricTensor
        """
        int8_tensor, scale = int8_symmetric_quantize(float_tensor)
        return Int8SymmetricTensor(int8_tensor, scale)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """
        Called for each ATen operator that our subclass is passed as an input to.
        We need to define our own implementation for every operator here.
        """
        if kwargs is None:
            kwargs = {}
        if func not in op_implementations_dict:
            raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
        return op_implementations_dict[func](func, *args, **kwargs)


# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
    def impl_decorator(op_impl):
        global op_implementations_dict
        for op in ops:
            op_implementations_dict[op] = op_impl
        return op_impl

    return impl_decorator

在上面的代码中,我们做了几件事

  1. 定义了一个基本的“包装器”张量子类——它实际上是一个容器对象,包含一些内部数据(特别是对应于我们的 int8 数据和比例的两个张量)

  2. 定义了一个 __torch_dispatch__ 实现,当模型对我们的任何子类输入调用任何 ATen 运算符时,都会调用它

  3. (对于 PT2 支持)定义了 __tensor_flatten__/__tensor_unflatten__ 方法。这是我们子类要与 torch.compile 一起工作所需的最大要求之一(稍后会详细介绍)。它有效地告诉 torch.compile 如何将我们的子类“解糖”为它的内部组件。

  4. (对于 PT2 支持)为两个构造函数方法(__new____init__)添加了 torch._dynamo.disable 装饰器(稍后会详细介绍)。

我们应该实现哪些运算符?

PyTorch 有一个相当大的运算符接口。我们不应该尝试让新的张量子类达到 100% 的覆盖率,而应该专注于我们上面玩具模型所需的运算符。

那么,我们的模型中调用了哪些运算符呢,这样我们才能知道首先要实现什么?暴力方法是重复运行模型,查看子类中出现哪些运算符错误。更优雅的方法是记录模型在执行期间看到的所有运算符。这可以通过另一个 LoggingTensor 子类来实现,如此示例所示。

让我们实现下面必要的运算符

from torch.utils._python_dispatch import return_and_correct_aliasing

@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
    assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
    return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale

@register_op([
    torch.ops.aten.detach.default,
    torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
    assert isinstance(args[0], Int8SymmetricTensor)
    out_data = func(args[0].int_data, *args[1:], **kwargs)
    out_scale = func(args[0].scale, *args[1:], **kwargs)
    out = Int8SymmetricTensor(out_data, out_scale)
    return return_and_correct_aliasing(func, args, kwargs, out)

你会很快注意到一件事:我们的模型本身由几个线性层组成,但我们看到一些操作,如 aten.taten.mm 击中了我们的子类。一些背景

  • 我们有许多 C++ 中存在的运算符分解,它们运行在张量子类“之上”。linear 就是其中之一(分解位于此处

  • 分解在某种意义上是好的,因为它们缩小了作为子类作者必须实现的 API 大小。但如果你宁愿覆盖“更高级别”的运算符而不是其分解中的底层操作,它们可能会很痛苦。

  • 如果你更喜欢在更高级别覆盖某些操作(如 Linear),你可以使用 __torch_function__ (示例) 来做到这一点。值得注意的是,如果你想要自动梯度支持,那么你在 __torch_function__ 层执行的任何覆盖都需要以可微分的方式编写,而你在 __torch_dispatch__ 中执行的任何覆盖将自动可微分。

我们的实现中有一些值得指出的细微差别

  1. 你会注意到我们不再需要在 mm 实现中转置权重/比例。那是因为转置在到达 aten.mm 操作之前“已经发生”了。

  2. 我们的 aten.mm 实现返回张量子类输出。从这个意义上说,我们量化子类的“传播”以矩阵乘法结束。这映射到我们的权重是低精度的,但我们需要以高精度执行矩阵乘法本身的事实。通常,子类作者可以自由选择他们的子类对哪些操作进行传播或不传播。如果你希望模型中的每个函数都被量化(包括所有逐点和缩减操作),你可以编写子类实现来量化每个操作的输出并始终返回一个子类。

  3. 我们能够对 4 个视图操作重用相同的实现。通常,许多操作可能使用一个非常通用的实现:解包任何子类输入,对内部张量运行底层操作符,并将输出包装回子类。

    • 然而,你是否总是可以重用一个实现,取决于你正在尝试做什么。例如,我们在子类上实现了 transpose(dim0, dim1),通过对内部数据和内部比例张量调用相同的转置。如果我们的比例和数据张量具有不同的维度数量,这将不起作用,因此在这种情况下,转置将需要自定义实现。

比较输出

完成所有这些之后,让我们用两种量化版本运行模型,并确认它们给出相同的输出!

float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model_module_swap, name, new_linear)

# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
    if type(child) == torch.nn.Linear:
        subclass_param = Int8SymmetricTensor.from_float(child.weight)
        child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)

with torch.no_grad():
    x = torch.randn(64, 64, 64, device='cuda')
    out_module_swap = quantized_model_module_swap(x)
    out = quantized_model_subclass(x)
    print(torch.allclose(out, out_module_swap))  # prints True

    # We can also use torch.compile to fuse some of our quantized logic
    out_compiled = torch.compile(quantized_model_subclass)(x)
    print(torch.allclose(out, out_compiled))  # prints True

下一步

在本教程中,我们演示了如何构建一个简单的量化张量子类。这是本系列两个教程的第一部分。下一篇文章将讨论如何向张量子类添加更高级的功能,例如使其可训练、与 DTensors 组合以及添加张量并行支持。有关 torchao 中 AffineQuantizedTensor 如何使用张量子类构建的更详细示例,另请参阅此示例

如果您在实现子类时有任何疑问,请随时在此处提交问题。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源