与 VLLM 集成:架构和使用指南¶
本教程全面概述了 TorchAO 如何与 VLLM 集成,以及为使新技术端到端工作需要实现什么。
配置系统¶
1. HuggingFace 模型配置¶
TorchAO 量化通过模型的 config.json
文件配置
{
"model_type": "llama",
"quant_type": {
"default": {
"_type": "Int4WeightOnlyConfig",
"_data": {
"group_size": 128,
"use_hqq": true
}
}
}
}
2. TorchAO 配置类¶
所有量化方法都继承自 AOBaseConfig
from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig
# Example configuration
config = Int4WeightOnlyConfig(
group_size=128,
use_hqq=True,
)
assert isinstance(config, AOBaseConfig)
注意
所有量化配置都继承自 torchao.core.config.AOBaseConfig
,该类提供序列化和验证功能。
3. 模块级配置¶
对于精细控制,请使用 ModuleFqnToConfig
from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig
config = ModuleFqnToConfig({
"model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
"_default": Int4WeightOnlyConfig(group_size=128) # Default for other modules
})
使用示例¶
1. 使用 HuggingFace 集成量化模型¶
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig
# Create quantization configuration
quantization_config = TorchAoConfig(
quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)
)
# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)
另请参阅
有关量化配置的更多信息,请参阅 torchao.quantization.Int4WeightOnlyConfig
和 torchao.quantization.Int8WeightOnlyConfig
。
2. 使用 VLLM 提供服务¶
# Start VLLM server with TorchAO quantized model
vllm serve your-username/Llama-3.2-1B-int4 \
--quantization torchao \
--dtype bfloat16 \
向 VLLM 添加新的量化方法¶
VLLM 兼容的最低要求¶
要使新的 TorchAO 量化方法与 VLLM 一起工作,您需要实现支持张量并行的最小张量子类操作。VLLM 使用 narrow()
和 copy_()
将状态字典中加载的主机 CPU 数据移动到设备,这些操作需要特定的 aten 操作。
为什么是这些?¶
VLLM 的张量并行工作原理如下:
一个有用的模式是 _apply_fn_to_data
,它将给定函数应用于类中所有具有 Tensor 类型的属性。下面是一个通用的实现,应该适用于大多数子类。我们在 torchao 代码库中大量使用这种模式
def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to all tensor components stored on this class"""
tensor_names, ctx = self.__tensor_flatten__()
# Apply the function to each tensor component
new_tensors = {}
for name in tensor_names:
new_tensors[name] = fn(getattr(self, name))
return self.__class__.__tensor_unflatten__(
new_tensors,
ctx,
None, # outer_size parameter
None, # outer_stride parameter
)
添加新量化方法的逐步指南¶
1. 创建您的 Tensor 子类¶
注意
有关张量子类及其设计原则的更多详细信息,请参阅 什么是张量子类? 文档。
from torchao.core.config import AOBaseConfig
from torchao.utils import TorchAOBaseTensor
@dataclass
class MyNewQuantConfig(AOBaseConfig):
"""Configuration for your new quantization method"""
bits: int = 8
VERSION: ClassVar[int] = 1
class MyQuantizedTensor(TorchAOBaseTensor):
"""Example based on FbgemmFp8Tensor - stores quantized data + scale"""
tensor_data_attrs = ["quantized_data", "scale"]
tensor_attributes = ["dtype"]
def __new__(cls, quantized_data, scale, dtype):
shape = quantized_data.shape
return torch.Tensor._make_wrapper_subclass(
cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False
)
def __init__(self, quantized_data, scale, dtype):
self.quantized_data = quantized_data
self.scale = scale
def __tensor_flatten__(self) -> Tuple[List[str], List]:
"""Serialize tensor subclass into plain tensors and metadata"""
return self.tensor_data_attrs, [
getattr(self, attr) for attr in self.tensor_attributes
]
@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, torch.Tensor],
tensor_attributes: List,
outer_size: Optional[torch.Size],
outer_stride: Optional[Tuple],
) -> "MyQuantizedTensor":
"""Reconstruct tensor subclass from serialized data"""
return cls(
*[tensor_data_dict[name] for name in cls.tensor_data_attrs],
*tensor_attributes,
)
2. 实现所需的 VLLM 操作¶
from torch.utils._python_dispatch import return_and_correct_aliasing
@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(func)
)
@MyQuantizedTensor.implements([aten._to_copy.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
@MyQuantizedTensor.implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0 or dim == 1:
# NOTE the slicing here will likely be different for different quant techniques
return return_and_correct_aliasing(
func, args, kwargs,
args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
else:
raise NotImplementedError(f"Slicing along dim={dim} not supported")
3. 在 TorchAO 的量化系统中注册¶
from torchao.quantization.transform_module import register_quantize_module_handler
@register_quantize_module_handler(MyNewQuantConfig)
def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig):
"""Transform function that applies your quantization to a module"""
weight = module.weight
# Your quantization logic here
quantized_weight = my_quantization_function(weight, config)
# Replace the weight with your quantized tensor
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module
重要提示
装饰器 torchao.quantization.transform_module.register_quantize_module_handler()
将您的配置类注册到 TorchAO 的量化系统中。
关键实现细节¶
硬件特定的线性操作¶
量化张量的正向传递决定了硬件支持以及当调用 torch.nn.functional.linear()
时实际调用的内容。
@MyQuantizedTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None
# This is where you define what hardware your method supports
if hasattr(weight_tensor, 'use_cutlass_kernel'):
return my_cutlass_linear(input_tensor, weight_tensor, bias)
elif hasattr(weight_tensor, 'use_triton_kernel'):
return my_triton_linear(input_tensor, weight_tensor, bias)
else:
# Fallback - dequantize and use standard linear
return torch.nn.functional.linear(
input_tensor, weight_tensor.dequantize(), bias
)
编译优势¶
张量子类的开销通过 torch.compile()
消失,这在 VLLM 中默认启用。
Tensor 子类的权衡¶
编译:对于消除子类开销至关重要。如果没有它,除非您的模型极度受 GPU 限制,否则 CPU 上的调度开销会严重影响性能。
检查点定义了模型的行为。您可能会说“所有检查点都是这样吗?”。这没错,但人们通常只将 torch.Tensor 视为其数据。而实际上,它是一个真正的类,它带来了调度器和 ATen 注册的所有内核。当您定义张量子类时,您正在构建一个独立的微型世界。它具有不同的数据表示,但您也需要明确定义支持哪些操作以及支持所有硬件的实现。这起初可能感觉有点像远距离的幽灵动作。但它可能非常强大。一个例子是仅通过 3 个定义就能支持 TP。
序列化和模型共享¶
SafeTensors 支持¶
当前状态:由于张量子类的限制,TorchAO 量化模型尚不能使用 safetensors 序列化。保存量化模型时,必须使用 safe_serialization=False
。
变通方法:对于生产使用,推送到 HuggingFace Hub 时,请使用 safe_serialization=False
保存模型。
未来工作:TorchAO 团队正在积极开发对张量子类的 safetensors 支持。跟踪进度:pytorch/ao#2338
集成架构图¶
1. 高级模型流:Transformers → VLLM + TorchAO¶
此图显示了从模型创建到提供服务的端到端流程
graph LR A[HuggingFace Model] --> B[Transformers AutoModel] B --> C{Quantization Config?} C -->|TorchAO Config| D[Apply TorchAO Quantization] C -->|No Config| E[Standard Model] D --> F[Quantized Model w/ Tensor Subclasses] E --> G[Standard PyTorch Model] F --> H[VLLM Model Loading] G --> H H --> I[VLLM Distributed Engine] I --> J[Tensor Parallel Sharding] J --> K[Optimized Inference] style D fill:#e1f5fe style F fill:#f3e5f5 style J fill:#e8f5e8
2. VLLM 中的 TorchAO 集成点¶
这显示了 VLLM 如何检测和应用 TorchAO 量化
graph LR A[Model Config Detection] --> B{quantization=torchao?} B -->|Yes| C[TorchAOConfig.from_config] B -->|No| D[Other Quantization Methods] C --> E[Parse HF quant_type] E --> F[config_from_dict] F --> G[AOBaseConfig Instance] G --> H[get_quant_method per layer] H --> I{Layer Type?} I -->|LinearBase| J[TorchAOLinearMethod] I -->|Other| K[UnquantizedLinearMethod] J --> L[create_weights] L --> M[torchao_quantize_param_data] M --> N[Quantized Tensor Subclass] style C fill:#e1f5fe style G fill:#f3e5f5 style N fill:#e8f5e8
3. 内核调度:将外部内核引入 VLLM¶
这说明了张量子类如何在 VLLM 中实现自定义内核调度
graph LR A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function] B --> C[Custom implements Handler] C --> D{Hardware Check} D --> E[Dispatch to External Kernel] E --> F[Execute Optimized Kernel] F --> G[Return Result to VLLM] subgraph "External Libraries" H[TorchAO CUTLASS] I[TorchAO Triton] J[FBGEMM-GPU] K[Custom Libraries] end subgraph "Tensor Subclass Code" L[implements F.linear] M[custom_linear_impl] N[call external kernel] end E --> H E --> I E --> J E --> K C --> L L --> M M --> N N --> E style B fill:#e8f6ff,color:#000 style C fill:#fff3e0,color:#000 style E fill:#e8f5e8,color:#000 style L fill:#f3e5f5,color:#000