(原型) 使用半结构化(2:4)稀疏性加速 BERT#
作者: Jesse Cai
与其他形式的稀疏性一样,半结构化稀疏性是一种模型优化技术,旨在以牺牲一些模型精度为代价来减少神经网络的内存开销和延迟。它也称为细粒度结构化稀疏性或2:4 结构化稀疏性。
半结构化稀疏性因其独特的稀疏模式而得名,其中在每 2n 个元素中有 n 个被剪枝。我们最常看到 n=2,因此是 2:4 稀疏性。半结构化稀疏性特别有趣,因为它可以有效地在 GPU 上加速,并且不会像其他稀疏模式那样严重降低模型精度。
通过引入半结构化稀疏性支持,无需离开 PyTorch 即可剪枝和加速半结构化稀疏模型。在本教程中,我们将解释此过程。

在本教程结束时,我们将使一个 BERT 问答模型稀疏化为 2:4 稀疏,并对其进行微调,以恢复几乎所有的 F1 损失(86.92 密集 vs 86.48 稀疏)。最后,我们将加速这个 2:4 稀疏模型进行推理,实现 1.3 倍的加速。
要求#
PyTorch >= 2.1。
支持半结构化稀疏性的 NVIDIA GPU(计算能力 8.0+)。
注意
本教程专为半结构化稀疏性/一般稀疏性的初学者设计。对于拥有现有 2:4 稀疏模型的用户来说,使用 to_sparse_semi_structured
为 nn.Linear
层加速推理就像:
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True
# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)
x = torch.rand(3072, 10240).half().cuda()
with torch.inference_mode():
dense_output = linear(x)
dense_t = Timer(stmt="linear(x)",
globals={"linear": linear,
"x": x}).blocked_autorange().median * 1e3
# accelerate via SparseSemiStructuredTensor
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
sparse_output = linear(x)
sparse_t = Timer(stmt="linear(x)",
globals={"linear": linear,
"x": x}).blocked_autorange().median * 1e3
# sparse and dense matmul are numerically equivalent
assert torch.allclose(sparse_output, dense_output, atol=1e-3)
print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
在 A100 80GB 上,我们看到:密集:0.870ms 稀疏:0.630ms | 加速:1.382x
半结构化稀疏性解决了什么问题?#
稀疏性的普遍动机很简单:如果您的网络中有零,您可以避免存储/计算这些参数。然而,稀疏性的具体细节很棘手。开箱即用地将参数归零不会影响我们模型的延迟/内存开销。
这是因为密集张量仍然包含被剪枝(零)的元素,密集矩阵乘法内核仍然会处理这些元素。为了实现性能提升,我们需要用稀疏内核替换密集内核,这些内核会跳过涉及被剪枝元素的计算。
为此,这些内核使用稀疏矩阵,这些矩阵不存储被剪枝的元素,而是以压缩格式存储指定的元素。
对于半结构化稀疏性,我们存储原始参数的一半以及一些关于元素排列方式的压缩元数据。
存在许多不同的稀疏布局,各有优缺点。2:4 半结构化稀疏布局特别有趣,原因有二:1. 与之前的稀疏格式不同,半结构化稀疏性旨在在 GPU 上高效加速。
2020 年,NVIDIA 通过其 Ampere 架构引入了对半结构化稀疏性的硬件支持,并通过 CUTLASS/cuSPARSELt 发布了快速稀疏内核。
同时,与稀疏格式相比,半结构化稀疏性对模型精度的影响通常较小,尤其是在考虑更高级的剪枝/微调方法时。NVIDIA 在其白皮书中表明,一种简单的范例,即一次性进行幅度剪枝以实现 2:4 稀疏,然后重新训练模型,可以获得几乎相同的模型精度。
半结构化稀疏性处于一个最佳点,在低得多的稀疏度(50%)下提供 2 倍(理论)加速,同时仍然足够精细以保持模型精度。
网络 |
数据集 |
指标 |
密集 FP16 |
稀疏 FP16 |
---|---|---|---|---|
ResNet-50 |
ImageNet |
Top-1 |
76.1 |
76.2 |
ResNeXt-101_32x8d |
ImageNet |
Top-1 |
79.3 |
79.3 |
Xception |
ImageNet |
Top-1 |
79.2 |
79.2 |
SSD-RN50 |
COCO2017 |
bbAP |
24.8 |
24.8 |
MaskRCNN-RN50 |
COCO2017 |
bbAP |
37.9 |
37.9 |
FairSeq Transformer |
EN-DE WMT14 |
BLEU |
28.2 |
28.5 |
BERT-Large |
SQuAD v1.1 |
F1 |
91.9 |
91.9 |
从工作流程的角度来看,半结构化稀疏性还具有一个额外的优势。由于稀疏度固定为 50%,因此更容易将模型稀疏化问题分解为两个不同的子问题:
精度 - 我们如何找到一组 2:4 稀疏权重,以最大限度地减少我们模型的精度下降?
性能 - 我们如何加速我们的 2:4 稀疏权重以进行推理并减少内存开销?
这两个问题之间的自然交接点是归零的密集张量。我们的推理解决方案旨在以这种格式压缩和加速张量。我们预计许多用户会提出自定义掩码解决方案,因为这是一个活跃的研究领域。
现在我们已经了解了更多关于半结构化稀疏性的信息,让我们将其应用于在问答任务 SQuAD 上训练的 BERT 模型。
简介和设置#
让我们从导入所有需要的包开始。
import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers
# force CUTLASS use if cuSPARSELt is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)
我们还需要定义一些特定于数据集/任务的辅助函数。这些函数是从这个 Hugging Face 课程改编而来的,作为参考。
def preprocess_validation_function(examples, tokenizer):
inputs = tokenizer(
[q.strip() for q in examples["question"]],
examples["context"],
max_length=384,
truncation="only_second",
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
sample_map = inputs.pop("overflow_to_sample_mapping")
example_ids = []
for i in range(len(inputs["input_ids"])):
sample_idx = sample_map[i]
example_ids.append(examples["id"][sample_idx])
sequence_ids = inputs.sequence_ids(i)
offset = inputs["offset_mapping"][i]
inputs["offset_mapping"][i] = [
o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
]
inputs["example_id"] = example_ids
return inputs
def preprocess_train_function(examples, tokenizer):
inputs = tokenizer(
[q.strip() for q in examples["question"]],
examples["context"],
max_length=384,
truncation="only_second",
return_offsets_mapping=True,
padding="max_length",
)
offset_mapping = inputs["offset_mapping"]
answers = examples["answers"]
start_positions = []
end_positions = []
for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
start_char = answer["answer_start"][0]
end_char = start_char + len(answer["text"][0])
sequence_ids = inputs.sequence_ids(i)
# Find the start and end of the context
idx = 0
while sequence_ids[idx] != 1:
idx += 1
context_start = idx
while sequence_ids[idx] == 1:
idx += 1
context_end = idx - 1
# If the answer is not fully inside the context, label it (0, 0)
if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
start_positions.append(0)
end_positions.append(0)
else:
# Otherwise it's the start and end token positions
idx = context_start
while idx <= context_end and offset[idx][0] <= start_char:
idx += 1
start_positions.append(idx - 1)
idx = context_end
while idx >= context_start and offset[idx][1] >= end_char:
idx -= 1
end_positions.append(idx + 1)
inputs["start_positions"] = start_positions
inputs["end_positions"] = end_positions
return inputs
def compute_metrics(start_logits, end_logits, features, examples):
n_best = 20
max_answer_length = 30
metric = evaluate.load("squad")
example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(features):
example_to_features[feature["example_id"]].append(idx)
predicted_answers = []
# for example in tqdm(examples):
for example in examples:
example_id = example["id"]
context = example["context"]
answers = []
# Loop through all features associated with that example
for feature_index in example_to_features[example_id]:
start_logit = start_logits[feature_index]
end_logit = end_logits[feature_index]
offsets = features[feature_index]["offset_mapping"]
start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Skip answers that are not fully in the context
if offsets[start_index] is None or offsets[end_index] is None:
continue
# Skip answers with a length that is either < 0
# or > max_answer_length
if (
end_index < start_index
or end_index - start_index + 1 > max_answer_length
):
continue
answer = {
"text": context[
offsets[start_index][0] : offsets[end_index][1]
],
"logit_score": start_logit[start_index] + end_logit[end_index],
}
answers.append(answer)
# Select the answer with the best score
if len(answers) > 0:
best_answer = max(answers, key=lambda x: x["logit_score"])
predicted_answers.append(
{"id": example_id, "prediction_text": best_answer["text"]}
)
else:
predicted_answers.append({"id": example_id, "prediction_text": ""})
theoretical_answers = [
{"id": ex["id"], "answers": ex["answers"]} for ex in examples
]
return metric.compute(predictions=predicted_answers, references=theoretical_answers)
在这些函数定义好之后,我们只需要一个额外的辅助函数,它将帮助我们对模型进行基准测试。
def measure_execution_time(model, batch_sizes, dataset):
dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
dataset_for_model.set_format("torch")
model.cuda()
batch_size_to_time_sec = {}
for batch_size in batch_sizes:
batch = {
k: dataset_for_model[k][:batch_size].to(model.device)
for k in dataset_for_model.column_names
}
with torch.inference_mode():
timer = benchmark.Timer(
stmt="model(**batch)", globals={"model": model, "batch": batch}
)
p50 = timer.blocked_autorange().median * 1000
batch_size_to_time_sec[batch_size] = p50
return batch_size_to_time_sec
我们将从加载模型和分词器开始,然后设置我们的数据集。
# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")
# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
lambda x: preprocess_validation_function(x, tokenizer),
batched=True,
remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)
接下来,我们将对我们的模型在 SQuAD 上进行快速的基线训练。此任务要求我们的模型在给定上下文(维基百科文章)中识别回答给定问题的跨度或文本段。运行以下代码,我得到了 86.9 的 F1 分数。这非常接近报告的 NVIDIA 分数,差异可能是由于 BERT-base 与 BERT-large 或微调超参数造成的。
training_args = transformers.TrainingArguments(
"trainer",
num_train_epochs=1,
lr_scheduler_type="constant",
per_device_train_batch_size=64,
per_device_eval_batch_size=512,
)
trainer = transformers.Trainer(
model,
training_args,
train_dataset=tokenized_squad_dataset["train"],
eval_dataset=tokenized_squad_dataset["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
)
trainer.train()
# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity require fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
fp16_baseline = compute_metrics(
start_logits,
end_logits,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
fp16_time = measure_execution_time(
model,
batch_sizes,
tokenized_squad_dataset["validation"],
)
print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)
# fp16 {'exact_match': 78.53358561967833, 'f1': 86.9280493093186}
# cuda_fp16 time {4: 10.927572380751371, 16: 19.607915310189128, 64: 73.18846387788653, 256: 286.91255673766136}
将 BERT 剪枝为 2:4 稀疏#
现在我们有了基线,是时候剪枝 BERT 了。有许多不同的剪枝策略,但最常见的之一是幅度剪枝,它试图移除 L1 范数最低的权重。NVIDIA 在其所有结果中都使用了幅度剪枝,这是一种常见的基线。
为此,我们将使用 torch.ao.pruning
包,该包包含一个权重范数(幅度)稀疏器。这些稀疏器通过将掩码参数化应用于模型中的权重张量来工作。这允许它们通过掩盖被剪枝的权重来模拟稀疏性。
我们还需要决定将稀疏性应用于模型的哪些层,在本例中是所有 nn.Linear 层,但特定于任务的头部输出除外。因为半结构化稀疏性具有形状约束,而特定于任务的 nn.Linear 层不满足这些约束。
sparsifier = WeightNormSparsifier(
# apply sparsity to all blocks
sparsity_level=1.0,
# shape of 4 elemens is a block
sparse_block_shape=(1, 4),
# two zeros for every block of 4
zeros_per_block=2
)
# add to config if nn.Linear and in the BERT model.
sparse_config = [
{"tensor_fqn": f"{fqn}.weight"}
for fqn, module in model.named_modules()
if isinstance(module, nn.Linear) and "layer" in fqn
]
剪枝模型的第一步是插入参数化以掩盖模型权重。这通过 prepare 步骤完成。任何时候我们尝试访问 .weight
,我们都会得到 mask * weight
。
# Prepare the model, insert fake-sparsity parameterizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)
# BertOutput(
# (dense): ParametrizedLinear(
# in_features=3072, out_features=768, bias=True
# (parametrizations): ModuleDict(
# (weight): ParametrizationList(
# (0-5): 6 x FakeSparsity()
# )
# )
# )
# (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
# (dropout): Dropout(p=0.1, inplace=False)
# )
然后,我们将执行一次剪枝。所有剪枝器都实现了一个 update_mask()
方法,该方法使用由剪枝器实现决定的逻辑来更新掩码。step 方法会调用这些 update_mask
函数来处理稀疏配置中指定的权重。
我们还将评估模型,以显示零样本剪枝(即不进行微调/重新训练的剪枝)的精度下降。
sparsifier.step()
with torch.autocast("cuda"):
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
pruned = compute_metrics(
*predictions.predictions,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("pruned eval metrics:", pruned)
# pruned eval metrics: {'exact_match': 40.59602649006622, 'f1': 56.51610004515979}
在这种状态下,我们可以开始微调模型,更新不会被剪枝的元素,以更好地弥补精度损失。一旦达到满意状态,我们就可以调用 squash_mask
将掩码和权重融合在一起。这将删除参数化,我们得到一个零填充的 2:4 密集模型。
trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight)
# Parameter containing:
# tensor([[ 0.0000, -0.0237, 0.0000, 0.0130, ..., -0.0462, -0.0000, 0.0000, -0.0272],
# [ 0.0436, -0.0000, -0.0000, 0.0492, ..., -0.0000, 0.0844, 0.0340, -0.0000],
# [-0.0302, -0.0350, 0.0000, 0.0000, ..., 0.0303, 0.0175, -0.0000, 0.0000],
# [ 0.0000, -0.0000, -0.0529, 0.0327, ..., 0.0213, 0.0000, -0.0000, 0.0735],
# ...,
# [ 0.0000, -0.0000, -0.0258, -0.0239, ..., -0.0000, -0.0000, 0.0380, 0.0562],
# [-0.0432, -0.0000, 0.0000, -0.0598, ..., 0.0000, -0.0000, 0.0262 -0.0227],
# [ 0.0244, 0.0921, -0.0000, -0.0000, ..., -0.0000, -0.0784, 0.0000, 0.0761],
# [ 0.0000, 0.0225, -0.0395, -0.0000, ..., -0.0000, 0.0684, -0.0344, -0.0000]], device='cuda:0', requires_grad=True)
加速 2:4 稀疏模型进行推理 ——–i———————————— 现在我们有了一个这种格式的模型,我们可以像在快速入门指南中一样加速它进行推理。
model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
if isinstance(module, nn.Linear) and "layer" in fqn:
module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
start_logits,
end_logits,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
model,
batch_sizes,
tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)
# sparse eval metrics: {'exact_match': 78.43897824030275, 'f1': 86.48718950090766}
# sparse perf metrics: {4: 12.621004460379481, 16: 15.368514601141214, 64: 58.702805917710066, 256: 244.19364519417286}
在幅度剪枝后重新训练模型已恢复了剪枝时损失的几乎所有 F1。同时,我们实现了 1.28 倍的加速(bs=16)。请注意,并非所有形状都适合性能改进。当批次大小较小且在计算稀疏内核上花费的时间有限时,它们可能比密集内核慢。
指标 |
fp16 |
2:4 稀疏 |
增量/加速 |
---|---|---|---|
精确匹配 (%) |
78.53 |
78.44 |
-0.09 |
F1 (%) |
86.93 |
86.49 |
-0.44 |
时间(bs=4) |
10.93 |
12.62 |
0.87x |
时间(bs=16) |
19.61 |
15.37 |
1.28x |
时间(bs=64) |
73.19 |
58.70 |
1.25x |
时间(bs=256) |
286.91 |
244.19 |
1.18x |
结论#
在本教程中,我们展示了如何将 BERT 剪枝为 2:4 稀疏,以及如何加速 2:4 稀疏模型进行推理。通过利用我们的 SparseSemiStructuredTensor 子类,我们实现了比 fp16 基线高 1.3 倍的加速。我们还通过微调 BERT 来恢复任何损失的 F1(86.92 密集 vs 86.48 稀疏)来展示了 2:4 稀疏的好处。