• 文档 >
  • 与 VLLM 的集成:架构和使用指南
快捷方式

与 VLLM 集成:架构和使用指南

本教程全面概述了 TorchAO 如何与 VLLM 集成,以及为使新技术端到端工作需要实现什么。

配置系统

1. HuggingFace 模型配置

TorchAO 量化通过模型的 config.json 文件配置

{
  "model_type": "llama",
  "quant_type": {
    "default": {
      "_type": "Int4WeightOnlyConfig",
      "_data": {
        "group_size": 128,
        "use_hqq": true
      }
    }
  }
}

2. TorchAO 配置类

所有量化方法都继承自 AOBaseConfig

from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig

# Example configuration
config = Int4WeightOnlyConfig(
    group_size=128,
    use_hqq=True,
)
assert isinstance(config, AOBaseConfig)

注意

所有量化配置都继承自 torchao.core.config.AOBaseConfig,该类提供序列化和验证功能。

3. 模块级配置

对于精细控制,请使用 ModuleFqnToConfig

from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

config = ModuleFqnToConfig({
    "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
    "_default": Int4WeightOnlyConfig(group_size=128)  # Default for other modules
})

使用示例

1. 使用 HuggingFace 集成量化模型

from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)
)

# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)

另请参阅

有关量化配置的更多信息,请参阅 torchao.quantization.Int4WeightOnlyConfigtorchao.quantization.Int8WeightOnlyConfig

2. 使用 VLLM 提供服务

# Start VLLM server with TorchAO quantized model
vllm serve your-username/Llama-3.2-1B-int4 \
    --quantization torchao \
    --dtype bfloat16 \

向 VLLM 添加新的量化方法

VLLM 兼容的最低要求

要使新的 TorchAO 量化方法与 VLLM 一起工作,您需要实现支持张量并行的最小张量子类操作。VLLM 使用 narrow()copy_() 将状态字典中加载的主机 CPU 数据移动到设备,这些操作需要特定的 aten 操作。

为什么是这些?

VLLM 的张量并行工作原理如下:

  1. narrow() - 跨不同维度切片权重张量

  2. 分片 - 将张量块分布到多个 GPU 上

  3. copy_() - 在设备之间移动张量数据

  4. detach()

一个有用的模式是 _apply_fn_to_data,它将给定函数应用于类中所有具有 Tensor 类型的属性。下面是一个通用的实现,应该适用于大多数子类。我们在 torchao 代码库中大量使用这种模式

def _apply_fn_to_data(self, fn: Callable):
    """Applies a fn to all tensor components stored on this class"""
    tensor_names, ctx = self.__tensor_flatten__()

    # Apply the function to each tensor component
    new_tensors = {}
    for name in tensor_names:
        new_tensors[name] = fn(getattr(self, name))

    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        None,  # outer_size parameter
        None,  # outer_stride parameter
    )

添加新量化方法的逐步指南

1. 创建您的 Tensor 子类

注意

有关张量子类及其设计原则的更多详细信息,请参阅 什么是张量子类? 文档。

from torchao.core.config import AOBaseConfig
from torchao.utils import TorchAOBaseTensor

@dataclass
class MyNewQuantConfig(AOBaseConfig):
    """Configuration for your new quantization method"""
    bits: int = 8
    VERSION: ClassVar[int] = 1

class MyQuantizedTensor(TorchAOBaseTensor):
    """Example based on FbgemmFp8Tensor - stores quantized data + scale"""

    tensor_data_attrs = ["quantized_data", "scale"]
    tensor_attributes = ["dtype"]

    def __new__(cls, quantized_data, scale, dtype):
        shape = quantized_data.shape
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False
        )

    def __init__(self, quantized_data, scale, dtype):
        self.quantized_data = quantized_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], List]:
        """Serialize tensor subclass into plain tensors and metadata"""
        return self.tensor_data_attrs, [
            getattr(self, attr) for attr in self.tensor_attributes
        ]

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, torch.Tensor],
        tensor_attributes: List,
        outer_size: Optional[torch.Size],
        outer_stride: Optional[Tuple],
    ) -> "MyQuantizedTensor":
        """Reconstruct tensor subclass from serialized data"""
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
            *tensor_attributes,
        )

2. 实现所需的 VLLM 操作

from torch.utils._python_dispatch import return_and_correct_aliasing

@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(func)
    )

@MyQuantizedTensor.implements([aten._to_copy.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )

@MyQuantizedTensor.implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    if dim == 0 or dim == 1:
        # NOTE the slicing here will likely be different for different quant techniques
        return return_and_correct_aliasing(
            func, args, kwargs,
            args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
        )
    else:
        raise NotImplementedError(f"Slicing along dim={dim} not supported")

3. 在 TorchAO 的量化系统中注册

from torchao.quantization.transform_module import register_quantize_module_handler

@register_quantize_module_handler(MyNewQuantConfig)
def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig):
    """Transform function that applies your quantization to a module"""
    weight = module.weight

    # Your quantization logic here
    quantized_weight = my_quantization_function(weight, config)

    # Replace the weight with your quantized tensor
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module

重要提示

装饰器 torchao.quantization.transform_module.register_quantize_module_handler() 将您的配置类注册到 TorchAO 的量化系统中。

关键实现细节

硬件特定的线性操作

量化张量的正向传递决定了硬件支持以及当调用 torch.nn.functional.linear() 时实际调用的内容。

@MyQuantizedTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None

    # This is where you define what hardware your method supports
    if hasattr(weight_tensor, 'use_cutlass_kernel'):
        return my_cutlass_linear(input_tensor, weight_tensor, bias)
    elif hasattr(weight_tensor, 'use_triton_kernel'):
        return my_triton_linear(input_tensor, weight_tensor, bias)
    else:
        # Fallback - dequantize and use standard linear
        return torch.nn.functional.linear(
            input_tensor, weight_tensor.dequantize(), bias
        )

编译优势

张量子类的开销通过 torch.compile() 消失,这在 VLLM 中默认启用。

Tensor 子类的权衡

  1. 编译:对于消除子类开销至关重要。如果没有它,除非您的模型极度受 GPU 限制,否则 CPU 上的调度开销会严重影响性能。

  2. 检查点定义了模型的行为。您可能会说“所有检查点都是这样吗?”。这没错,但人们通常只将 torch.Tensor 视为其数据。而实际上,它是一个真正的类,它带来了调度器和 ATen 注册的所有内核。当您定义张量子类时,您正在构建一个独立的微型世界。它具有不同的数据表示,但您也需要明确定义支持哪些操作以及支持所有硬件的实现。这起初可能感觉有点像远距离的幽灵动作。但它可能非常强大。一个例子是仅通过 3 个定义就能支持 TP。

序列化和模型共享

SafeTensors 支持

当前状态:由于张量子类的限制,TorchAO 量化模型尚不能使用 safetensors 序列化。保存量化模型时,必须使用 safe_serialization=False

变通方法:对于生产使用,推送到 HuggingFace Hub 时,请使用 safe_serialization=False 保存模型。

未来工作:TorchAO 团队正在积极开发对张量子类的 safetensors 支持。跟踪进度:pytorch/ao#2338

集成架构图

1. 高级模型流:Transformers → VLLM + TorchAO

此图显示了从模型创建到提供服务的端到端流程

        graph LR
    A[HuggingFace Model] --> B[Transformers AutoModel]
    B --> C{Quantization Config?}
    C -->|TorchAO Config| D[Apply TorchAO Quantization]
    C -->|No Config| E[Standard Model]

    D --> F[Quantized Model w/ Tensor Subclasses]
    E --> G[Standard PyTorch Model]

    F --> H[VLLM Model Loading]
    G --> H

    H --> I[VLLM Distributed Engine]
    I --> J[Tensor Parallel Sharding]
    J --> K[Optimized Inference]

    style D fill:#e1f5fe
    style F fill:#f3e5f5
    style J fill:#e8f5e8
    

2. VLLM 中的 TorchAO 集成点

这显示了 VLLM 如何检测和应用 TorchAO 量化

        graph LR
    A[Model Config Detection] --> B{quantization=torchao?}
    B -->|Yes| C[TorchAOConfig.from_config]
    B -->|No| D[Other Quantization Methods]

    C --> E[Parse HF quant_type]
    E --> F[config_from_dict]
    F --> G[AOBaseConfig Instance]

    G --> H[get_quant_method per layer]
    H --> I{Layer Type?}
    I -->|LinearBase| J[TorchAOLinearMethod]
    I -->|Other| K[UnquantizedLinearMethod]

    J --> L[create_weights]
    L --> M[torchao_quantize_param_data]
    M --> N[Quantized Tensor Subclass]

    style C fill:#e1f5fe
    style G fill:#f3e5f5
    style N fill:#e8f5e8
    

3. 内核调度:将外部内核引入 VLLM

这说明了张量子类如何在 VLLM 中实现自定义内核调度

        graph LR
    A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function]
    B --> C[Custom implements Handler]
    C --> D{Hardware Check}

    D --> E[Dispatch to External Kernel]
    E --> F[Execute Optimized Kernel]
    F --> G[Return Result to VLLM]

    subgraph "External Libraries"
        H[TorchAO CUTLASS]
        I[TorchAO Triton]
        J[FBGEMM-GPU]
        K[Custom Libraries]
    end

    subgraph "Tensor Subclass Code"
        L[implements F.linear]
        M[custom_linear_impl]
        N[call external kernel]
    end

    E --> H
    E --> I
    E --> J
    E --> K

    C --> L
    L --> M
    M --> N
    N --> E

    style B fill:#e8f6ff,color:#000
    style C fill:#fff3e0,color:#000
    style E fill:#e8f5e8,color:#000
    style L fill:#f3e5f5,color:#000
    

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源