注意
转到末尾 下载完整的示例代码。
修剪教程#
创建于: 2019年7月22日 | 最后更新: 2023年11月02日 | 最后验证: 2024年11月05日
作者: Michela Paganini
最先进的深度学习技术依赖于难以部署的过度参数化模型。相反,生物神经网络以使用高效的稀疏连接而闻名。识别通过减少模型中的参数数量来压缩模型的最佳技术非常重要,以减少内存、电池和硬件消耗,而又不牺牲准确性。这反过来又允许您在设备上部署轻量级模型,并通过私有的设备内计算来保证隐私。在研究方面,修剪被用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化的作用(“彩票”),作为一种破坏性的神经架构搜索技术,等等。
在本教程中,您将学习如何使用 torch.nn.utils.prune 来稀疏化您的神经网络,以及如何扩展它来实现您自己的自定义修剪技术。
要求#
"torch>=1.4.0a0+8e8a5e0"
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
创建模型#
在本教程中,我们使用了 LeCun 等人 (1998) 的 LeNet 架构。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
检查模块#
让我们检查一下 LeNet 模型中(未修剪的)conv1 层。它将包含两个参数 weight 和 bias,目前没有缓冲区。
module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.1080, 0.1777, -0.0219, 0.1556, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0840],
[ 0.1834, 0.0396, 0.1947, -0.0930, -0.1996],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.1610, -0.1277, 0.0102, -0.0191, -0.0627],
[-0.1233, -0.0103, 0.0556, 0.0748, 0.1583],
[ 0.0701, -0.1686, 0.0733, -0.1530, -0.0384],
[-0.1136, -0.0863, 0.0755, -0.1585, -0.1921],
[-0.0318, 0.1514, 0.1999, 0.0979, 0.0559]]],
[[[ 0.0826, -0.1019, -0.1807, -0.0031, 0.1562],
[ 0.0134, 0.0204, -0.0599, -0.0034, 0.0462],
[ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
[-0.1300, -0.1447, 0.0057, -0.0971, -0.1935],
[-0.1217, -0.1738, 0.1224, -0.1521, 0.0138]]],
[[[-0.0396, -0.1639, 0.1371, -0.1733, -0.0824],
[-0.0278, -0.1693, 0.0440, 0.1116, -0.0702],
[ 0.0930, -0.1650, 0.1249, -0.0173, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0841, 0.1897]]],
[[[-0.0385, 0.1232, 0.1315, 0.1062, -0.0976],
[ 0.1838, -0.1291, 0.1153, 0.1173, 0.0644],
[-0.1098, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0470],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0974, -0.1663, -0.1839, 0.1924, -0.0193],
[ 0.0538, 0.0496, -0.1254, 0.0740, -0.1996],
[-0.0378, 0.0121, 0.1558, -0.1539, -0.1766],
[-0.1681, 0.0488, 0.1711, 0.1994, -0.0155],
[-0.1179, 0.0486, 0.1481, -0.0658, -0.0872]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993, 0.1291, -0.1555, -0.1203], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[]
修剪模块#
要修剪一个模块(在本例中是 LeNet 架构的 conv1 层),首先从 torch.nn.utils.prune 中提供的修剪技术中选择一种(或通过继承 BasePruningMethod 实现 您自己的)。然后,指定要修剪的模块和该模块内参数的名称。最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数。
在此示例中,我们将随机修剪 conv1 层中名为 weight 的参数的 30% 的连接。模块作为函数的第一个参数传递;name 使用其字符串标识符标识该模块内的参数;amount 指示要修剪的连接的百分比(如果它是介于 0. 和 1. 之间的浮点数),或者要修剪的连接的绝对数量(如果它是一个非负整数)。
prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
修剪通过从参数中删除 weight 并用名为 weight_orig 的新参数替换它(即在初始参数 name 后面追加 "_orig")来起作用。weight_orig 存储了未修剪的张量版本。bias 未被修剪,因此它将保持不变。
print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993, 0.1291, -0.1555, -0.1203], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1080, 0.1777, -0.0219, 0.1556, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0840],
[ 0.1834, 0.0396, 0.1947, -0.0930, -0.1996],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.1610, -0.1277, 0.0102, -0.0191, -0.0627],
[-0.1233, -0.0103, 0.0556, 0.0748, 0.1583],
[ 0.0701, -0.1686, 0.0733, -0.1530, -0.0384],
[-0.1136, -0.0863, 0.0755, -0.1585, -0.1921],
[-0.0318, 0.1514, 0.1999, 0.0979, 0.0559]]],
[[[ 0.0826, -0.1019, -0.1807, -0.0031, 0.1562],
[ 0.0134, 0.0204, -0.0599, -0.0034, 0.0462],
[ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
[-0.1300, -0.1447, 0.0057, -0.0971, -0.1935],
[-0.1217, -0.1738, 0.1224, -0.1521, 0.0138]]],
[[[-0.0396, -0.1639, 0.1371, -0.1733, -0.0824],
[-0.0278, -0.1693, 0.0440, 0.1116, -0.0702],
[ 0.0930, -0.1650, 0.1249, -0.0173, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0841, 0.1897]]],
[[[-0.0385, 0.1232, 0.1315, 0.1062, -0.0976],
[ 0.1838, -0.1291, 0.1153, 0.1173, 0.0644],
[-0.1098, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0470],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0974, -0.1663, -0.1839, 0.1924, -0.0193],
[ 0.0538, 0.0496, -0.1254, 0.0740, -0.1996],
[-0.0378, 0.0121, 0.1558, -0.1539, -0.1766],
[-0.1681, 0.0488, 0.1711, 0.1994, -0.0155],
[-0.1179, 0.0486, 0.1481, -0.0658, -0.0872]]]], device='cuda:0',
requires_grad=True))]
由上述修剪技术生成的修剪掩码将作为名为 weight_mask 的模块缓冲区保存(即在初始参数 name 后面追加 "_mask")。
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 0., 0., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 1., 1., 1.],
[0., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 1.],
[0., 0., 1., 1., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 0., 1.],
[1., 1., 1., 0., 1.],
[0., 0., 1., 1., 1.]]],
[[[1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1.],
[0., 0., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1.]]],
[[[1., 0., 1., 0., 1.],
[1., 0., 0., 1., 0.],
[0., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 1., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 0., 1., 1.]]]], device='cuda:0'))]
为了使前向传播能够正常工作而无需修改,需要存在 weight 属性。在 torch.nn.utils.prune 中实现的修剪技术会计算权重的修剪版本(通过将掩码与原始参数组合),并将它们存储在 weight 属性中。请注意,这不再是 module 的参数,现在它只是一个属性。
print(module.weight)
tensor([[[[ 0.0000, 0.1777, -0.0219, 0.0000, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0000],
[ 0.1834, 0.0000, 0.0000, -0.0930, -0.0000],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.0000, -0.0000, 0.0102, -0.0191, -0.0627],
[-0.0000, -0.0103, 0.0000, 0.0748, 0.1583],
[ 0.0701, -0.1686, 0.0733, -0.1530, -0.0384],
[-0.1136, -0.0000, 0.0755, -0.1585, -0.1921],
[-0.0000, 0.0000, 0.1999, 0.0979, 0.0559]]],
[[[ 0.0826, -0.0000, -0.1807, -0.0031, 0.1562],
[ 0.0134, 0.0204, -0.0599, -0.0034, 0.0462],
[ 0.1143, -0.0000, -0.0000, -0.0000, -0.0187],
[-0.1300, -0.1447, 0.0057, -0.0000, -0.1935],
[-0.0000, -0.0000, 0.1224, -0.1521, 0.0138]]],
[[[-0.0396, -0.0000, 0.0000, -0.1733, -0.0000],
[-0.0278, -0.1693, 0.0000, 0.0000, -0.0702],
[ 0.0000, -0.0000, 0.1249, -0.0000, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0000, 0.1897]]],
[[[-0.0385, 0.0000, 0.1315, 0.0000, -0.0976],
[ 0.1838, -0.0000, 0.0000, 0.1173, 0.0000],
[-0.0000, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0000],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0000, -0.0000, -0.0000, 0.1924, -0.0193],
[ 0.0000, 0.0000, -0.1254, 0.0740, -0.1996],
[-0.0000, 0.0121, 0.0000, -0.0000, -0.1766],
[-0.1681, 0.0488, 0.1711, 0.1994, -0.0155],
[-0.1179, 0.0486, 0.0000, -0.0658, -0.0872]]]], device='cuda:0',
grad_fn=<MulBackward0>)
最后,使用 PyTorch 的 forward_pre_hooks 在每次前向传播之前应用修剪。具体来说,当 module 被修剪时,正如我们在这里所做的那样,它将为与之关联的每个被修剪的参数获取一个 forward_pre_hook。在这种情况下,由于我们到目前为止只修剪了名为 weight 的原始参数,因此只会存在一个钩子。
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>)])
为了完整起见,我们现在也可以修剪 bias,以了解模块的参数、缓冲区、钩子和属性如何变化。仅为了尝试另一种修剪技术,这里我们根据 L1 范数修剪 bias 中的 3 个最小的条目,这在 l1_unstructured 修剪函数中实现。
prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
我们现在预计命名参数将包括 weight_orig(来自之前)和 bias_orig。缓冲区将包括 weight_mask 和 bias_mask。两个张量的修剪版本将作为模块属性存在,并且该模块现在将具有两个 forward_pre_hooks。
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1080, 0.1777, -0.0219, 0.1556, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0840],
[ 0.1834, 0.0396, 0.1947, -0.0930, -0.1996],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.1610, -0.1277, 0.0102, -0.0191, -0.0627],
[-0.1233, -0.0103, 0.0556, 0.0748, 0.1583],
[ 0.0701, -0.1686, 0.0733, -0.1530, -0.0384],
[-0.1136, -0.0863, 0.0755, -0.1585, -0.1921],
[-0.0318, 0.1514, 0.1999, 0.0979, 0.0559]]],
[[[ 0.0826, -0.1019, -0.1807, -0.0031, 0.1562],
[ 0.0134, 0.0204, -0.0599, -0.0034, 0.0462],
[ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
[-0.1300, -0.1447, 0.0057, -0.0971, -0.1935],
[-0.1217, -0.1738, 0.1224, -0.1521, 0.0138]]],
[[[-0.0396, -0.1639, 0.1371, -0.1733, -0.0824],
[-0.0278, -0.1693, 0.0440, 0.1116, -0.0702],
[ 0.0930, -0.1650, 0.1249, -0.0173, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0841, 0.1897]]],
[[[-0.0385, 0.1232, 0.1315, 0.1062, -0.0976],
[ 0.1838, -0.1291, 0.1153, 0.1173, 0.0644],
[-0.1098, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0470],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0974, -0.1663, -0.1839, 0.1924, -0.0193],
[ 0.0538, 0.0496, -0.1254, 0.0740, -0.1996],
[-0.0378, 0.0121, 0.1558, -0.1539, -0.1766],
[-0.1681, 0.0488, 0.1711, 0.1994, -0.0155],
[-0.1179, 0.0486, 0.1481, -0.0658, -0.0872]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993, 0.1291, -0.1555, -0.1203], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 0., 0., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 1., 1., 1.],
[0., 1., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 1., 1., 1.],
[0., 0., 1., 1., 1.]]],
[[[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 0., 1.],
[1., 1., 1., 0., 1.],
[0., 0., 1., 1., 1.]]],
[[[1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1.],
[0., 0., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1.]]],
[[[1., 0., 1., 0., 1.],
[1., 0., 0., 1., 0.],
[0., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 1., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]
print(module.bias)
tensor([ 0.1364, -0.0000, -0.1993, 0.0000, -0.1555, -0.0000], device='cuda:0',
grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f5e194cb0d0>)])
迭代修剪#
模块中的同一参数可以被多次修剪,各种修剪调用的效果等同于一系列应用各种掩码的组合。新掩码与旧掩码的组合由 PruningContainer 的 compute_mask 方法处理。
例如,假设我们现在想进一步修剪 module.weight,这次使用沿张量第 0 轴的结构化修剪(第 0 轴对应卷积层的输出通道,对于 conv1 维度为 6),基于通道的 L2 范数。这可以通过 ln_structured 函数实现,其中 n=2 和 dim=0。
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[ 0.0000, 0.1777, -0.0219, 0.0000, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0000],
[ 0.1834, 0.0000, 0.0000, -0.0930, -0.0000],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000]]],
[[[-0.0396, -0.0000, 0.0000, -0.1733, -0.0000],
[-0.0278, -0.1693, 0.0000, 0.0000, -0.0702],
[ 0.0000, -0.0000, 0.1249, -0.0000, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0000, 0.1897]]],
[[[-0.0385, 0.0000, 0.1315, 0.0000, -0.0976],
[ 0.1838, -0.0000, 0.0000, 0.1173, 0.0000],
[-0.0000, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0000],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
相应的钩子现在将是 torch.nn.utils.prune.PruningContainer 类型,并将存储应用于 weight 参数的修剪历史。
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>, <torch.nn.utils.prune.LnStructured object at 0x7f5e194c88e0>]
序列化修剪后的模型#
所有相关的张量,包括掩码缓冲区和用于计算修剪后的张量的原始参数,都存储在模型的 state_dict 中,因此如果需要,可以轻松地进行序列化和保存。
print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
移除修剪重参数化#
为了使修剪永久化,移除关于 weight_orig 和 weight_mask 的重参数化,并移除 forward_pre_hook,我们可以使用 torch.nn.utils.prune 中的 remove 功能。请注意,这并不会撤消修剪,好像它从未发生过一样。相反,它通过将 weight 参数重新分配给模型参数(在其修剪后的版本中)来使其永久化。
在移除重参数化之前
print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1080, 0.1777, -0.0219, 0.1556, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0840],
[ 0.1834, 0.0396, 0.1947, -0.0930, -0.1996],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.1610, -0.1277, 0.0102, -0.0191, -0.0627],
[-0.1233, -0.0103, 0.0556, 0.0748, 0.1583],
[ 0.0701, -0.1686, 0.0733, -0.1530, -0.0384],
[-0.1136, -0.0863, 0.0755, -0.1585, -0.1921],
[-0.0318, 0.1514, 0.1999, 0.0979, 0.0559]]],
[[[ 0.0826, -0.1019, -0.1807, -0.0031, 0.1562],
[ 0.0134, 0.0204, -0.0599, -0.0034, 0.0462],
[ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
[-0.1300, -0.1447, 0.0057, -0.0971, -0.1935],
[-0.1217, -0.1738, 0.1224, -0.1521, 0.0138]]],
[[[-0.0396, -0.1639, 0.1371, -0.1733, -0.0824],
[-0.0278, -0.1693, 0.0440, 0.1116, -0.0702],
[ 0.0930, -0.1650, 0.1249, -0.0173, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0841, 0.1897]]],
[[[-0.0385, 0.1232, 0.1315, 0.1062, -0.0976],
[ 0.1838, -0.1291, 0.1153, 0.1173, 0.0644],
[-0.1098, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0470],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0974, -0.1663, -0.1839, 0.1924, -0.0193],
[ 0.0538, 0.0496, -0.1254, 0.0740, -0.1996],
[-0.0378, 0.0121, 0.1558, -0.1539, -0.1766],
[-0.1681, 0.0488, 0.1711, 0.1994, -0.0155],
[-0.1179, 0.0486, 0.1481, -0.0658, -0.0872]]]], device='cuda:0',
requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993, 0.1291, -0.1555, -0.1203], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 0., 0., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1.],
[0., 0., 1., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1.]]],
[[[1., 0., 1., 0., 1.],
[1., 0., 0., 1., 0.],
[0., 1., 1., 1., 1.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]
print(module.weight)
tensor([[[[ 0.0000, 0.1777, -0.0219, 0.0000, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0000],
[ 0.1834, 0.0000, 0.0000, -0.0930, -0.0000],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000]]],
[[[-0.0396, -0.0000, 0.0000, -0.1733, -0.0000],
[-0.0278, -0.1693, 0.0000, 0.0000, -0.0702],
[ 0.0000, -0.0000, 0.1249, -0.0000, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0000, 0.1897]]],
[[[-0.0385, 0.0000, 0.1315, 0.0000, -0.0976],
[ 0.1838, -0.0000, 0.0000, 0.1173, 0.0000],
[-0.0000, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0000],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
移除重参数化之后
prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993, 0.1291, -0.1555, -0.1203], device='cuda:0',
requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, 0.1777, -0.0219, 0.0000, -0.1139],
[ 0.1726, -0.0548, -0.0320, 0.0731, -0.1833],
[-0.1037, 0.1749, 0.0621, 0.1258, -0.0000],
[ 0.1834, 0.0000, 0.0000, -0.0930, -0.0000],
[-0.1061, -0.0661, 0.0420, -0.1807, 0.1205]]],
[[[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000]]],
[[[-0.0396, -0.0000, 0.0000, -0.1733, -0.0000],
[-0.0278, -0.1693, 0.0000, 0.0000, -0.0702],
[ 0.0000, -0.0000, 0.1249, -0.0000, -0.0074],
[ 0.1675, 0.0054, -0.1918, -0.0846, -0.0560],
[-0.1026, 0.1980, -0.1918, 0.0000, 0.1897]]],
[[[-0.0385, 0.0000, 0.1315, 0.0000, -0.0976],
[ 0.1838, -0.0000, 0.0000, 0.1173, 0.0000],
[-0.0000, -0.1352, 0.1762, 0.0470, 0.1758],
[ 0.1444, -0.1419, 0.1106, 0.0789, 0.0000],
[ 0.0996, 0.0549, 0.0470, 0.1610, 0.1657]]],
[[[ 0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, -0.0000]]]], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]
修剪模型中的多个参数#
通过指定所需的修剪技术和参数,我们可以轻松地修剪网络中的多个张量,也许根据它们的类型,正如我们将在本例中看到的。
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
全局修剪#
到目前为止,我们只看了通常被称为“局部”修剪的内容,即逐个修剪模型中的张量,仅将每个条目的统计数据(权重幅度、激活、梯度等)与其在张量中的其他条目进行比较。然而,一种常见且可能更强大的技术是同时修剪模型,例如,通过移除整个模型中最低的 20% 的连接,而不是移除每个层中最低的 20% 的连接。这可能会导致每层的修剪百分比不同。让我们看看如何使用 torch.nn.utils.prune 中的 global_unstructured 来实现这一点。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
现在我们可以检查每个修剪参数中引起的稀疏性,它不会等于每层的 20%。然而,全局稀疏性将是(大约)20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
Sparsity in conv1.weight: 6.67%
Sparsity in conv2.weight: 13.96%
Sparsity in fc1.weight: 22.14%
Sparsity in fc2.weight: 12.31%
Sparsity in fc3.weight: 9.64%
Global sparsity: 20.00%
使用自定义修剪函数扩展 torch.nn.utils.prune#
要实现自己的修剪函数,您可以像其他所有修剪方法一样,通过继承 BasePruningMethod 基类来扩展 nn.utils.prune 模块。基类为您实现了以下方法:__call__、apply_mask、apply、prune 和 remove。除了某些特殊情况,您无需为新的修剪技术重新实现这些方法。但是,您将需要实现 __init__(构造函数)和 compute_mask(根据您的修剪技术的逻辑计算给定张量掩码的说明)。此外,您必须指定此技术实现哪种类型的修剪(支持的选项是 global、structured 和 unstructured)。这对于确定如何在迭代应用修剪时组合掩码是必需的。换句话说,当修剪一个预先修剪过的参数时,当前修剪技术应该作用于参数中未修剪的部分。指定 PRUNING_TYPE 将使 PruningContainer(负责迭代应用修剪掩码)能够正确识别要修剪的参数切片。
例如,假设您想实现一种修剪技术,该技术修剪张量中的每隔一个条目(或者——如果张量先前已被修剪——则修剪张量中剩余未修剪部分中的每隔一个条目)。这将是 PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不是作用于整个单元/通道('structured')或跨不同参数('global')。
class FooBarPruningMethod(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
现在,要将其应用于 nn.Module 中的参数,您还应该提供一个简单的函数来实例化该方法并应用它。
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
让我们试试看!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
脚本总运行时间: (0 分钟 0.559 秒)