快捷方式

量化操作

本文档概述了如何利用量化操作在 XLA 设备上启用量化。

XLA 量化操作为量化操作(例如,块状 int4 量化矩阵乘法)提供了高级抽象。这些操作类似于 CUDA 生态系统中的量化 CUDA 内核(示例),在 XLA 框架内提供了类似的功能和性能优势。

注意: 当前此功能被classified为实验性功能。其 API 规范将在下一个(2.5)版本中发生变化。

如何使用:

XLA 量化操作可以作为 torch op 使用,也可以作为包装 torch.optorch.nn.Module 使用。这两种选项为模型开发人员提供了灵活性,可以根据自身需求选择最佳方式将 XLA 量化操作集成到其解决方案中。

torch opnn.Module 都与 torch.compile( backend='openxla') 兼容。

在模型代码中调用 XLA 量化操作

用户可以像调用其他常规 PyTorch 操作一样调用 XLA 量化操作。这为将 XLA 量化操作集成到其应用程序中提供了最大的灵活性。量化操作在 eager 模式和 Dynamo 中均可正常工作,支持常规的 PyTorch CPU 张量和 XLA 张量。

注意 请检查量化操作的文档字符串以了解量化权重的布局。

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul

N_INPUT_FEATURES=10
N_OUTPUT_FEATURES=20
x = torch.randn((3, N_INPUT_FEATURES), dtype=torch.bfloat16)
w_int = torch.randint(-128, 127, (N_OUTPUT_FEATURES, N_INPUT_FEATURES), dtype=torch.int8)
scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)

# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

device = torch_xla.device()
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)

# Call with XLA Tensor to run on XLA device
matmul_output_xla = torch.ops.xla.quantized_matmul(x_xla, w_int_xla, scaler_xla)

# Use with torch.compile(backend='openxla')
def f(x, w, s):
  return torch.ops.xla.quantized_matmul(x, w, s)

f_dynamo = torch.compile(f, backend="openxla")
dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla)

模型开发者通常会在其模型代码中将量化操作封装到自定义 nn.Module

class MyQLinearForXLABackend(torch.nn.Module):
  def __init__(self):
    self.weight = ...
    self.scaler = ...

  def load_weight(self, w, scaler):
    # Load quantized Linear weights
    # Customized way to preprocess the weights
    ...
    self.weight = processed_w
    self.scaler = processed_scaler


  def forward(self, x):
    # Do some random stuff with x
    ...
    matmul_output = torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
    # Do some random stuff with matmul_output
    ...

模块替换

或者,用户也可以使用包装 XLA 量化操作的 nn.Module,并在模型代码中进行模块替换。

orig_model = MyModel()
# Quantize the model and get quantized weights
q_weights = quantize(orig_model)
# Process the quantized weight to the format that XLA quantized op expects.
q_weights_for_xla = process_for_xla(q_weights)

# Do module swap
q_linear = XlaQuantizedLinear(self.linear.in_features,
                              self.linear.out_features)
q_linear.load_quantized_weight(q_weights_for_xla)
orig_model.linear = q_linear

支持的量化操作:

矩阵乘法

权重 激活 数据类型 支持
每通道(对称/非对称) W8A16
每通道(对称/非对称) 不适用 W8A8
每通道 每 token W8A8
每通道 每 token W4A8
块状(对称/非对称) 不适用 W8A16
块状(对称/非对称) 不适用 W8A16
块状 每 token W8A8
块状 每 token W4A8

注意 W[X]A[Y] 指的是 X 位数的权重,Y 位数的激活。如果 X/Y 为 4 或 8,则表示 int4/8。16 表示 bfloat16 格式。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源