注意
跳转至页面底部 下载完整示例代码。
剪枝教程#
创建日期:2019年7月22日 | 最后更新:2023年11月2日 | 最后验证:2024年11月5日
最先进的深度学习技术依赖于过参数化的模型,这使得它们难以部署。相反,生物神经网络以高效的稀疏连接而闻名。确定压缩模型的最佳技术(通过减少参数数量)非常重要,这样可以在不牺牲准确性的前提下降低内存、电池和硬件消耗。这进而使您能够在设备上部署轻量级模型,并通过私有的端侧计算保证隐私。在研究领域,剪枝被用于研究过参数化网络和欠参数化网络之间学习动态的差异,研究幸运的稀疏子网络和初始化的作用(“彩票假设”),以及作为一种破坏性的神经架构搜索技术等等。
在本教程中,您将学习如何使用 torch.nn.utils.prune 对神经网络进行稀疏化,以及如何扩展它以实现您自己的自定义剪枝技术。
要求#
"torch>=1.4.0a0+8e8a5e0"
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
创建模型#
在本教程中,我们使用 LeCun 等人 1998 年提出的 LeNet 架构。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
检查模块#
让我们检查 LeNet 模型中(未剪枝的)conv1 层。目前它包含两个参数 weight 和 bias,且没有缓存(buffers)。
module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[-0.1389, 0.1621, -0.0019, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.1106, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.1865, 0.1482, -0.0724, -0.0923]]],
[[[ 0.0749, 0.0413, 0.1687, 0.0545, -0.0305],
[-0.0107, 0.0103, 0.1008, -0.1294, 0.1947],
[ 0.0650, -0.0872, -0.1064, -0.0575, -0.1948],
[-0.1886, -0.0136, -0.1128, -0.1035, 0.0527],
[ 0.1278, -0.1355, 0.1927, -0.0167, -0.0984]]],
[[[-0.0584, -0.0702, 0.0756, -0.1504, 0.0927],
[-0.0653, -0.1520, -0.0465, 0.1141, -0.0568],
[-0.0111, 0.1255, 0.0831, -0.1745, -0.1861],
[-0.0190, -0.0416, -0.1819, 0.0339, 0.0036],
[ 0.1123, -0.1063, -0.0012, -0.0120, 0.0581]]],
[[[ 0.0058, -0.1075, -0.0344, -0.0881, 0.1850],
[ 0.1102, 0.0123, 0.1129, -0.1784, -0.0745],
[ 0.1194, -0.1984, 0.0663, 0.1699, 0.1600],
[ 0.0657, 0.0185, 0.1566, 0.0521, -0.0311],
[ 0.1633, 0.0994, -0.0083, 0.1258, -0.1734]]],
[[[ 0.1243, -0.0997, -0.0125, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.1565, -0.0715, 0.1499],
[ 0.0685, 0.1905, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.1781, 0.0445]]],
[[[-0.1397, 0.1159, 0.0754, 0.0839, 0.1721],
[-0.0329, -0.1404, 0.0211, -0.1976, -0.1909],
[ 0.0103, 0.0937, 0.1190, 0.0065, -0.1166],
[ 0.0831, 0.1909, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0681, 0.1991, 0.1042, 0.0379, -0.0034, 0.0177], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[]
剪枝模块#
要剪枝一个模块(在本例中为我们 LeNet 架构的 conv1 层),首先从 torch.nn.utils.prune 中选择一种可用的剪枝技术(或通过继承 BasePruningMethod 实现 您自己的技术)。然后,指定模块以及要在该模块内剪枝的参数名称。最后,使用所选剪枝技术所需的适当关键字参数,指定剪枝参数。
在此示例中,我们将随机剪枝 conv1 层中名为 weight 的参数中 30% 的连接。模块作为第一个参数传递给函数;name 使用字符串标识符识别模块内的参数;amount 表示要剪枝的连接百分比(如果是 0 到 1 之间的浮点数),或者要剪枝的连接绝对数量(如果是非负整数)。
prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
剪枝通过从参数中删除 weight 并将其替换为名为 weight_orig 的新参数(即在初始参数 name 后附加 "_orig")来起作用。weight_orig 存储张量的未剪枝版本。bias 未被剪枝,因此保持不变。
print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([-0.0681, 0.1991, 0.1042, 0.0379, -0.0034, 0.0177], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1389, 0.1621, -0.0019, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.1106, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.1865, 0.1482, -0.0724, -0.0923]]],
[[[ 0.0749, 0.0413, 0.1687, 0.0545, -0.0305],
[-0.0107, 0.0103, 0.1008, -0.1294, 0.1947],
[ 0.0650, -0.0872, -0.1064, -0.0575, -0.1948],
[-0.1886, -0.0136, -0.1128, -0.1035, 0.0527],
[ 0.1278, -0.1355, 0.1927, -0.0167, -0.0984]]],
[[[-0.0584, -0.0702, 0.0756, -0.1504, 0.0927],
[-0.0653, -0.1520, -0.0465, 0.1141, -0.0568],
[-0.0111, 0.1255, 0.0831, -0.1745, -0.1861],
[-0.0190, -0.0416, -0.1819, 0.0339, 0.0036],
[ 0.1123, -0.1063, -0.0012, -0.0120, 0.0581]]],
[[[ 0.0058, -0.1075, -0.0344, -0.0881, 0.1850],
[ 0.1102, 0.0123, 0.1129, -0.1784, -0.0745],
[ 0.1194, -0.1984, 0.0663, 0.1699, 0.1600],
[ 0.0657, 0.0185, 0.1566, 0.0521, -0.0311],
[ 0.1633, 0.0994, -0.0083, 0.1258, -0.1734]]],
[[[ 0.1243, -0.0997, -0.0125, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.1565, -0.0715, 0.1499],
[ 0.0685, 0.1905, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.1781, 0.0445]]],
[[[-0.1397, 0.1159, 0.0754, 0.0839, 0.1721],
[-0.0329, -0.1404, 0.0211, -0.1976, -0.1909],
[ 0.0103, 0.0937, 0.1190, 0.0065, -0.1166],
[ 0.0831, 0.1909, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
requires_grad=True))]
由上述所选剪枝技术生成的剪枝掩码(mask)保存为一个名为 weight_mask 的模块缓存(即在初始参数 name 后附加 "_mask")。
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]],
[[[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 1., 1., 0.],
[0., 1., 1., 1., 1.],
[1., 0., 1., 0., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 0., 1.],
[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 0.]]],
[[[1., 1., 1., 0., 1.],
[1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 0., 1., 0.]]],
[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 1., 1., 0., 0.]]],
[[[0., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[0., 1., 1., 0., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.]]]], device='cuda:0'))]
为了使前向传播无需修改即可工作,必须存在 weight 属性。torch.nn.utils.prune 中实现的剪枝技术计算权重的剪枝版本(通过将掩码与原始参数组合),并将其存储在 weight 属性中。请注意,这不再是 module 的参数,它现在只是一个属性。
print(module.weight)
tensor([[[[-0.1389, 0.1621, -0.0000, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.0000, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.0000, 0.0000, -0.0724, -0.0923]]],
[[[ 0.0749, 0.0000, 0.0000, 0.0000, -0.0305],
[-0.0107, 0.0000, 0.1008, -0.1294, 0.1947],
[ 0.0650, -0.0000, -0.1064, -0.0575, -0.0000],
[-0.0000, -0.0136, -0.1128, -0.1035, 0.0527],
[ 0.1278, -0.0000, 0.1927, -0.0000, -0.0984]]],
[[[-0.0584, -0.0000, 0.0756, -0.1504, 0.0927],
[-0.0653, -0.1520, -0.0465, 0.1141, -0.0568],
[-0.0111, 0.0000, 0.0000, -0.0000, -0.1861],
[-0.0190, -0.0416, -0.0000, 0.0339, 0.0036],
[ 0.1123, -0.1063, -0.0012, -0.0120, 0.0000]]],
[[[ 0.0058, -0.1075, -0.0344, -0.0000, 0.1850],
[ 0.1102, 0.0000, 0.0000, -0.0000, -0.0745],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.1600],
[ 0.0657, 0.0185, 0.1566, 0.0521, -0.0311],
[ 0.0000, 0.0994, -0.0000, 0.1258, -0.0000]]],
[[[ 0.1243, -0.0997, -0.0000, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.0000, -0.0000, 0.1499],
[ 0.0685, 0.0000, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.0000, 0.0000]]],
[[[-0.0000, 0.1159, 0.0000, 0.0000, 0.1721],
[-0.0329, -0.0000, 0.0211, -0.1976, -0.1909],
[ 0.0000, 0.0937, 0.1190, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
grad_fn=<MulBackward0>)
最后,在每次前向传播之前,使用 PyTorch 的 forward_pre_hooks 应用剪枝。具体来说,当 module 被剪枝时(如我们此处所做的那样),它将为其关联的每个被剪枝的参数获取一个 forward_pre_hook。在本例中,由于我们目前只剪枝了名为 weight 的原始参数,因此只会存在一个钩子。
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f20673df790>)])
为了完整起见,我们现在也可以剪枝 bias,看看模块的参数、缓存、钩子和属性是如何变化的。为了尝试另一种剪枝技术,这里我们通过 L1 范数剪枝偏差中 3 个最小的条目,这是在 l1_unstructured 剪枝函数中实现的。
prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
我们现在预计命名参数将包括 weight_orig(之前已存在)和 bias_orig。缓存将包括 weight_mask 和 bias_mask。这两个张量的剪枝版本将作为模块属性存在,并且该模块现在将有两个 forward_pre_hooks。
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.1389, 0.1621, -0.0019, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.1106, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.1865, 0.1482, -0.0724, -0.0923]]],
[[[ 0.0749, 0.0413, 0.1687, 0.0545, -0.0305],
[-0.0107, 0.0103, 0.1008, -0.1294, 0.1947],
[ 0.0650, -0.0872, -0.1064, -0.0575, -0.1948],
[-0.1886, -0.0136, -0.1128, -0.1035, 0.0527],
[ 0.1278, -0.1355, 0.1927, -0.0167, -0.0984]]],
[[[-0.0584, -0.0702, 0.0756, -0.1504, 0.0927],
[-0.0653, -0.1520, -0.0465, 0.1141, -0.0568],
[-0.0111, 0.1255, 0.0831, -0.1745, -0.1861],
[-0.0190, -0.0416, -0.1819, 0.0339, 0.0036],
[ 0.1123, -0.1063, -0.0012, -0.0120, 0.0581]]],
[[[ 0.0058, -0.1075, -0.0344, -0.0881, 0.1850],
[ 0.1102, 0.0123, 0.1129, -0.1784, -0.0745],
[ 0.1194, -0.1984, 0.0663, 0.1699, 0.1600],
[ 0.0657, 0.0185, 0.1566, 0.0521, -0.0311],
[ 0.1633, 0.0994, -0.0083, 0.1258, -0.1734]]],
[[[ 0.1243, -0.0997, -0.0125, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.1565, -0.0715, 0.1499],
[ 0.0685, 0.1905, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.1781, 0.0445]]],
[[[-0.1397, 0.1159, 0.0754, 0.0839, 0.1721],
[-0.0329, -0.1404, 0.0211, -0.1976, -0.1909],
[ 0.0103, 0.0937, 0.1190, 0.0065, -0.1166],
[ 0.0831, 0.1909, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0681, 0.1991, 0.1042, 0.0379, -0.0034, 0.0177], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]],
[[[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 0., 1., 1., 0.],
[0., 1., 1., 1., 1.],
[1., 0., 1., 0., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 0., 1.],
[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 0.]]],
[[[1., 1., 1., 0., 1.],
[1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 0., 1., 0.]]],
[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 1., 1., 0., 0.]]],
[[[0., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[0., 1., 1., 0., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
print(module.bias)
tensor([-0.0681, 0.1991, 0.1042, 0.0000, -0.0000, 0.0000], device='cuda:0',
grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f20673df790>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f20673df8b0>)])
迭代剪枝#
模块中的同一个参数可以被多次剪枝,各种剪枝调用的效果等同于串联应用的各种掩码的组合。新掩码与旧掩码的组合由 PruningContainer 的 compute_mask 方法处理。
例如,假设我们现在想要进一步剪枝 module.weight,这次使用基于通道 L2 范数的结构化剪枝(沿张量的第 0 轴,第 0 轴对应卷积层的输出通道,对于 conv1 其维度为 6)。这可以使用 ln_structured 函数实现,设置 n=2 和 dim=0。
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[-0.1389, 0.1621, -0.0000, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.0000, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.0000, 0.0000, -0.0724, -0.0923]]],
[[[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000]]],
[[[ 0.1243, -0.0997, -0.0000, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.0000, -0.0000, 0.1499],
[ 0.0685, 0.0000, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.0000, 0.0000]]],
[[[-0.0000, 0.1159, 0.0000, 0.0000, 0.1721],
[-0.0329, -0.0000, 0.0211, -0.1976, -0.1909],
[ 0.0000, 0.0937, 0.1190, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
grad_fn=<MulBackward0>)
相应的钩子现在的类型将是 torch.nn.utils.prune.PruningContainer,并将存储应用于 weight 参数的剪枝历史。
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f20673df790>, <torch.nn.utils.prune.LnStructured object at 0x7f20673dd570>]
序列化已剪枝的模型#
所有相关张量,包括掩码缓存和用于计算剪枝张量的原始参数,都存储在模型的 state_dict 中,因此如果需要,可以轻松地序列化和保存。
print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
移除剪枝重参数化#
为了使剪枝永久化,移除 weight_orig 和 weight_mask 方面的重参数化,并移除 forward_pre_hook,我们可以使用 torch.nn.utils.prune 中的 remove 功能。请注意,这并不会撤销剪枝。相反,它通过将参数 weight 重新分配为其剪枝后的版本,使其永久生效。
移除重参数化之前
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.1389, 0.1621, -0.0019, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.1106, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.1865, 0.1482, -0.0724, -0.0923]]],
[[[ 0.0749, 0.0413, 0.1687, 0.0545, -0.0305],
[-0.0107, 0.0103, 0.1008, -0.1294, 0.1947],
[ 0.0650, -0.0872, -0.1064, -0.0575, -0.1948],
[-0.1886, -0.0136, -0.1128, -0.1035, 0.0527],
[ 0.1278, -0.1355, 0.1927, -0.0167, -0.0984]]],
[[[-0.0584, -0.0702, 0.0756, -0.1504, 0.0927],
[-0.0653, -0.1520, -0.0465, 0.1141, -0.0568],
[-0.0111, 0.1255, 0.0831, -0.1745, -0.1861],
[-0.0190, -0.0416, -0.1819, 0.0339, 0.0036],
[ 0.1123, -0.1063, -0.0012, -0.0120, 0.0581]]],
[[[ 0.0058, -0.1075, -0.0344, -0.0881, 0.1850],
[ 0.1102, 0.0123, 0.1129, -0.1784, -0.0745],
[ 0.1194, -0.1984, 0.0663, 0.1699, 0.1600],
[ 0.0657, 0.0185, 0.1566, 0.0521, -0.0311],
[ 0.1633, 0.0994, -0.0083, 0.1258, -0.1734]]],
[[[ 0.1243, -0.0997, -0.0125, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.1565, -0.0715, 0.1499],
[ 0.0685, 0.1905, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.1781, 0.0445]]],
[[[-0.1397, 0.1159, 0.0754, 0.0839, 0.1721],
[-0.0329, -0.1404, 0.0211, -0.1976, -0.1909],
[ 0.0103, 0.0937, 0.1190, 0.0065, -0.1166],
[ 0.0831, 0.1909, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0681, 0.1991, 0.1042, 0.0379, -0.0034, 0.0177], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[1., 0., 0., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 1., 1., 0., 0.]]],
[[[0., 1., 0., 0., 1.],
[1., 0., 1., 1., 1.],
[0., 1., 1., 0., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
print(module.weight)
tensor([[[[-0.1389, 0.1621, -0.0000, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.0000, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.0000, 0.0000, -0.0724, -0.0923]]],
[[[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000]]],
[[[ 0.1243, -0.0997, -0.0000, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.0000, -0.0000, 0.1499],
[ 0.0685, 0.0000, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.0000, 0.0000]]],
[[[-0.0000, 0.1159, 0.0000, 0.0000, 0.1721],
[-0.0329, -0.0000, 0.0211, -0.1976, -0.1909],
[ 0.0000, 0.0937, 0.1190, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
grad_fn=<MulBackward0>)
移除重参数化之后
prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([-0.0681, 0.1991, 0.1042, 0.0379, -0.0034, 0.0177], device='cuda:0',
requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.1389, 0.1621, -0.0000, 0.1047, 0.1553],
[-0.0194, -0.1287, 0.0318, 0.1169, -0.0582],
[ 0.1399, -0.0526, -0.0593, 0.1251, -0.1715],
[ 0.0000, -0.0829, 0.1451, 0.0798, -0.1958],
[ 0.1881, 0.0000, 0.0000, -0.0724, -0.0923]]],
[[[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000]]],
[[[ 0.1243, -0.0997, -0.0000, -0.0211, 0.1585],
[-0.0679, -0.1805, 0.0228, -0.0113, -0.1980],
[-0.0058, 0.0726, 0.0000, -0.0000, 0.1499],
[ 0.0685, 0.0000, 0.0270, -0.0809, -0.1433],
[-0.0495, 0.1291, -0.0862, -0.0000, 0.0000]]],
[[[-0.0000, 0.1159, 0.0000, 0.0000, 0.1721],
[-0.0329, -0.0000, 0.0211, -0.1976, -0.1909],
[ 0.0000, 0.0937, 0.1190, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.1406, -0.1168, 0.1588],
[-0.1014, 0.1756, 0.1770, -0.0094, -0.0080]]]], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
剪枝模型中的多个参数#
通过指定所需的剪枝技术和参数,我们可以轻松地剪枝网络中的多个张量,例如根据它们的类型进行剪枝,我们将在本示例中看到。
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
全局剪枝#
到目前为止,我们只关注了通常所说的“局部”剪枝,即通过比较每个张量中各条目的统计数据(权重大小、激活值、梯度等)与该张量中其他条目的统计数据,逐个剪枝模型中张量的做法。然而,一种常见且可能更强大的技术是一次性剪枝整个模型,例如删除整个模型中 20% 的最小连接,而不是删除每一层中 20% 的最小连接。这很可能会导致每层的剪枝百分比不同。让我们看看如何使用 torch.nn.utils.prune 中的 global_unstructured 来实现这一点。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
现在我们可以检查每个已剪枝参数中诱导出的稀疏性,它在每一层中不会等于 20%。但是,全局稀疏度将(大约)为 20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
Sparsity in conv1.weight: 5.33%
Sparsity in conv2.weight: 13.21%
Sparsity in fc1.weight: 22.14%
Sparsity in fc2.weight: 12.35%
Sparsity in fc3.weight: 11.43%
Global sparsity: 20.00%
使用自定义剪枝函数扩展 torch.nn.utils.prune#
要实现您自己的剪枝函数,您可以像所有其他剪枝方法一样,通过继承 BasePruningMethod 基类来扩展 nn.utils.prune 模块。基类为您实现了以下方法:__call__, apply_mask, apply, prune 和 remove。除了某些特殊情况外,您无需为新的剪枝技术重新实现这些方法。但是,您必须实现 __init__(构造函数)和 compute_mask(根据您的剪枝技术逻辑计算给定张量掩码的指令)。此外,您还必须指定此技术实现的剪枝类型(支持的选项为 global, structured 和 unstructured)。这对于确定在迭代应用剪枝时如何组合掩码是必需的。换句话说,当剪枝一个已预剪枝的参数时,当前的剪枝技术预计会对参数的未剪枝部分进行操作。指定 PRUNING_TYPE 将使 PruningContainer(处理剪枝掩码的迭代应用)能够正确识别要剪枝的参数切片。
例如,假设您想要实现一种剪枝技术,该技术剪枝张量中的每隔一个条目(或者——如果张量之前已被剪枝——则剪枝剩余未剪枝部分的条目)。这将属于 PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不是整个单元/通道('structured')或跨不同的参数('global')。
class FooBarPruningMethod(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
现在,要将其应用于 nn.Module 中的参数,您还应该提供一个简单的函数来实例化该方法并应用它。
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
让我们试一试!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
脚本总运行时间:(0 分钟 0.539 秒)