评价此页

剪枝教程#

创建于:2019年7月22日 | 最后更新:2023年11月2日 | 最后验证:2024年11月5日

作者: Michela Paganini

最先进的深度学习技术依赖于过度参数化的模型,这些模型难以部署。相反,生物神经网络已知使用高效的稀疏连接。识别优化模型压缩技术,通过减少模型中的参数数量,以在不牺牲准确性的前提下减少内存、电池和硬件消耗非常重要。这反过来允许您在设备上部署轻量级模型,并保证通过私有设备上计算的隐私。在研究方面,剪枝被用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化的作用(“彩票票”),作为一种破坏性的神经架构搜索技术,以及更多。

在本教程中,您将学习如何使用 torch.nn.utils.prune 来稀疏化您的神经网络,以及如何扩展它以实现您自己的自定义剪枝技术。

要求#

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

创建模型#

在本教程中,我们使用 LeCun 等人,1998 年的 LeNet 架构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查一个模块#

让我们检查我们的 LeNet 模型中未剪枝的 conv1 层。它将包含两个参数 weightbias,目前没有缓冲区。

[('weight', Parameter containing:
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0485, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.1885],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.1072, -0.0950,  0.0800,  0.1870, -0.1998],
          [-0.1034,  0.0637,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0160,  0.1653, -0.1746, -0.1187],
          [-0.0332, -0.1939, -0.0348,  0.0954,  0.0174],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.1463, -0.0860, -0.0436, -0.0996],
          [ 0.1238,  0.1071, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.1670, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.1791],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.1681,  0.1553,  0.0715,  0.1201,  0.0706],
          [ 0.1194, -0.0622, -0.0143, -0.1870, -0.0589],
          [ 0.1232,  0.1836,  0.0275, -0.0334, -0.1969],
          [ 0.1895, -0.0932, -0.1651,  0.0179, -0.0509],
          [-0.0377,  0.0583,  0.1905, -0.0031, -0.0729]]],


        [[[ 0.1916,  0.0930,  0.1735,  0.0792,  0.0977],
          [ 0.0440, -0.1013,  0.1709, -0.0262, -0.1991],
          [-0.1626,  0.1974, -0.0097,  0.1050,  0.0525],
          [ 0.1470,  0.0845, -0.1195,  0.1524,  0.0614],
          [ 0.1687, -0.0960,  0.1612, -0.1088,  0.0619]]],


        [[[ 0.0116, -0.0408,  0.0774, -0.1025, -0.1783],
          [ 0.1765, -0.1064, -0.1461,  0.0860,  0.0039],
          [-0.0551, -0.0320,  0.0532, -0.0083, -0.0332],
          [-0.1784, -0.0250, -0.0903, -0.0074,  0.0160],
          [ 0.1507, -0.0781,  0.1486, -0.0808, -0.1729]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0167,  0.0863, -0.0232,  0.0925,  0.1978, -0.0268], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

剪枝一个模块#

要剪枝一个模块(在本例中,是我们的 LeNet 架构的 conv1 层),首先在 torch.nn.utils.prune 中可用的技术中选择一种剪枝技术(或 实现 您自己的方法,通过继承 BasePruningMethod)。然后,指定要在该模块中剪枝的模块和参数的名称。最后,使用所选剪枝技术所需的适当关键字参数,指定剪枝参数。

在本例中,我们将随机剪枝 conv1 层中名为 weight 的参数的 30% 的连接。将模块作为函数的第一个参数传递;name 使用其字符串标识符标识该模块内的参数;而 amount 指示要剪枝的连接百分比(如果它是在 0. 和 1. 之间的浮点数),或者要剪枝的绝对连接数(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

剪枝通过从参数中删除 weight 并将其替换为名为 weight_orig 的新参数(即,将 "_orig" 附加到初始参数 name)来起作用。 weight_orig 存储未剪枝的张量版本。 bias 未被剪枝,因此它将保持不变。

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([-0.0167,  0.0863, -0.0232,  0.0925,  0.1978, -0.0268], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0485, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.1885],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.1072, -0.0950,  0.0800,  0.1870, -0.1998],
          [-0.1034,  0.0637,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0160,  0.1653, -0.1746, -0.1187],
          [-0.0332, -0.1939, -0.0348,  0.0954,  0.0174],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.1463, -0.0860, -0.0436, -0.0996],
          [ 0.1238,  0.1071, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.1670, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.1791],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.1681,  0.1553,  0.0715,  0.1201,  0.0706],
          [ 0.1194, -0.0622, -0.0143, -0.1870, -0.0589],
          [ 0.1232,  0.1836,  0.0275, -0.0334, -0.1969],
          [ 0.1895, -0.0932, -0.1651,  0.0179, -0.0509],
          [-0.0377,  0.0583,  0.1905, -0.0031, -0.0729]]],


        [[[ 0.1916,  0.0930,  0.1735,  0.0792,  0.0977],
          [ 0.0440, -0.1013,  0.1709, -0.0262, -0.1991],
          [-0.1626,  0.1974, -0.0097,  0.1050,  0.0525],
          [ 0.1470,  0.0845, -0.1195,  0.1524,  0.0614],
          [ 0.1687, -0.0960,  0.1612, -0.1088,  0.0619]]],


        [[[ 0.0116, -0.0408,  0.0774, -0.1025, -0.1783],
          [ 0.1765, -0.1064, -0.1461,  0.0860,  0.0039],
          [-0.0551, -0.0320,  0.0532, -0.0083, -0.0332],
          [-0.1784, -0.0250, -0.0903, -0.0074,  0.0160],
          [ 0.1507, -0.0781,  0.1486, -0.0808, -0.1729]]]], device='cuda:0',
       requires_grad=True))]

剪枝技术选定的剪枝掩码被保存为名为 weight_mask 的模块缓冲区(即,将 "_mask" 附加到初始参数 name)。

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [0., 1., 1., 0., 0.],
          [1., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0'))]

为了使前向传递无需修改即可工作,weight 属性需要存在。 torch.nn.utils.prune 中实现的剪枝技术通过将掩码与原始参数组合来计算剪枝后的权重,并将其存储在 weight 属性中。请注意,这不再是 module 的一个参数,现在只是一个属性。

tensor([[[[-0.0887,  0.0490,  0.0977,  0.0000, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.0000],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.0000, -0.0950,  0.0000,  0.1870, -0.0000],
          [-0.1034,  0.0000,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0000,  0.0000, -0.1746, -0.0000],
          [-0.0332, -0.1939, -0.0000,  0.0954,  0.0000],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.0000, -0.0860, -0.0436, -0.0000],
          [ 0.1238,  0.0000, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.0000, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.0000],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.1681,  0.0000,  0.0715,  0.1201,  0.0706],
          [ 0.0000, -0.0622, -0.0000, -0.1870, -0.0589],
          [ 0.1232,  0.1836,  0.0000, -0.0334, -0.1969],
          [ 0.1895, -0.0000, -0.1651,  0.0179, -0.0000],
          [-0.0377,  0.0583,  0.0000, -0.0031, -0.0729]]],


        [[[ 0.1916,  0.0000,  0.1735,  0.0792,  0.0977],
          [ 0.0000, -0.1013,  0.1709, -0.0000, -0.0000],
          [-0.1626,  0.0000, -0.0097,  0.1050,  0.0000],
          [ 0.0000,  0.0845, -0.0000,  0.1524,  0.0614],
          [ 0.0000, -0.0960,  0.1612, -0.1088,  0.0619]]],


        [[[ 0.0000, -0.0408,  0.0000, -0.1025, -0.0000],
          [ 0.1765, -0.1064, -0.0000,  0.0860,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.1784, -0.0000, -0.0903, -0.0074,  0.0160],
          [ 0.1507, -0.0000,  0.0000, -0.0808, -0.1729]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最后,剪枝在每次前向传递之前使用 PyTorch 的 forward_pre_hooks 应用。具体来说,当 module 被剪枝时,如我们在此处所做的那样,它将为与其关联的每个参数获取一个 forward_pre_hook。在这种情况下,由于到目前为止我们只剪枝了名为 weight 的原始参数,因此只会存在一个钩子。

print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fe892269690>)])

为了完整起见,我们现在也可以剪枝 bias,看看 module 的参数、缓冲区、钩子和属性如何变化。只是为了尝试另一种剪枝技术,我们使用 L1 范数剪枝偏差中最小的 3 个条目,如 l1_unstructured 剪枝函数中实现的那样。

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

现在我们期望命名的参数包括 weight_orig(来自之前)和 bias_orig。缓冲区将包括 weight_maskbias_mask。这两个张量的剪枝版本将作为模块属性存在,并且该模块现在将有两个 forward_pre_hooks

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0485, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.1885],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.1072, -0.0950,  0.0800,  0.1870, -0.1998],
          [-0.1034,  0.0637,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0160,  0.1653, -0.1746, -0.1187],
          [-0.0332, -0.1939, -0.0348,  0.0954,  0.0174],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.1463, -0.0860, -0.0436, -0.0996],
          [ 0.1238,  0.1071, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.1670, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.1791],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.1681,  0.1553,  0.0715,  0.1201,  0.0706],
          [ 0.1194, -0.0622, -0.0143, -0.1870, -0.0589],
          [ 0.1232,  0.1836,  0.0275, -0.0334, -0.1969],
          [ 0.1895, -0.0932, -0.1651,  0.0179, -0.0509],
          [-0.0377,  0.0583,  0.1905, -0.0031, -0.0729]]],


        [[[ 0.1916,  0.0930,  0.1735,  0.0792,  0.0977],
          [ 0.0440, -0.1013,  0.1709, -0.0262, -0.1991],
          [-0.1626,  0.1974, -0.0097,  0.1050,  0.0525],
          [ 0.1470,  0.0845, -0.1195,  0.1524,  0.0614],
          [ 0.1687, -0.0960,  0.1612, -0.1088,  0.0619]]],


        [[[ 0.0116, -0.0408,  0.0774, -0.1025, -0.1783],
          [ 0.1765, -0.1064, -0.1461,  0.0860,  0.0039],
          [-0.0551, -0.0320,  0.0532, -0.0083, -0.0332],
          [-0.1784, -0.0250, -0.0903, -0.0074,  0.0160],
          [ 0.1507, -0.0781,  0.1486, -0.0808, -0.1729]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0167,  0.0863, -0.0232,  0.0925,  0.1978, -0.0268], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [0., 1., 1., 0., 0.],
          [1., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 0., 1., 1., 0.], device='cuda:0'))]
print(module.bias)
tensor([-0.0000, 0.0863, -0.0000, 0.0925, 0.1978, -0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fe892269690>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fe8922a0eb0>)])

迭代剪枝#

模块中的同一个参数可以多次剪枝,各种剪枝调用的效果等于按顺序应用各种掩码的组合。 PruningContainercompute_mask 方法处理新掩码与旧掩码的组合。

例如,如果我们现在想进一步剪枝 module.weight,这次使用沿张量第 0 轴的结构化剪枝(第 0 轴对应于卷积层的输出通道,对于 conv1 而言,其维度为 6),基于通道的 L2 范数。这可以使用 ln_structured 函数实现,n=2dim=0

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0000, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.0000],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.0000, -0.0950,  0.0000,  0.1870, -0.0000],
          [-0.1034,  0.0000,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0000,  0.0000, -0.1746, -0.0000],
          [-0.0332, -0.1939, -0.0000,  0.0954,  0.0000],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.0000, -0.0860, -0.0436, -0.0000],
          [ 0.1238,  0.0000, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.0000, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.0000],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

相应的钩子现在将是 torch.nn.utils.prune.PruningContainer 类型,并将存储应用于 weight 参数的剪枝历史记录。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7fe892269690>, <torch.nn.utils.prune.LnStructured object at 0x7fe8922a0c10>]

序列化剪枝模型#

所有相关张量,包括掩码缓冲区和用于计算剪枝张量的原始参数都存储在模型的 state_dict 中,因此如果需要,可以轻松地序列化和保存它们。

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

删除剪枝重新参数化#

为了使剪枝永久生效,删除 weight_origweight_mask 方面的重新参数化,并删除 forward_pre_hook,我们可以使用 torch.nn.utils.prune 中的 remove 功能。请注意,这不会撤消剪枝,就好像它从未发生过一样。相反,它会使其永久生效,而是将剪枝版本的参数 weight 重新分配给模型参数。

在删除重新参数化之前

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0485, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.1885],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.1072, -0.0950,  0.0800,  0.1870, -0.1998],
          [-0.1034,  0.0637,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0160,  0.1653, -0.1746, -0.1187],
          [-0.0332, -0.1939, -0.0348,  0.0954,  0.0174],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.1463, -0.0860, -0.0436, -0.0996],
          [ 0.1238,  0.1071, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.1670, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.1791],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.1681,  0.1553,  0.0715,  0.1201,  0.0706],
          [ 0.1194, -0.0622, -0.0143, -0.1870, -0.0589],
          [ 0.1232,  0.1836,  0.0275, -0.0334, -0.1969],
          [ 0.1895, -0.0932, -0.1651,  0.0179, -0.0509],
          [-0.0377,  0.0583,  0.1905, -0.0031, -0.0729]]],


        [[[ 0.1916,  0.0930,  0.1735,  0.0792,  0.0977],
          [ 0.0440, -0.1013,  0.1709, -0.0262, -0.1991],
          [-0.1626,  0.1974, -0.0097,  0.1050,  0.0525],
          [ 0.1470,  0.0845, -0.1195,  0.1524,  0.0614],
          [ 0.1687, -0.0960,  0.1612, -0.1088,  0.0619]]],


        [[[ 0.0116, -0.0408,  0.0774, -0.1025, -0.1783],
          [ 0.1765, -0.1064, -0.1461,  0.0860,  0.0039],
          [-0.0551, -0.0320,  0.0532, -0.0083, -0.0332],
          [-0.1784, -0.0250, -0.0903, -0.0074,  0.0160],
          [ 0.1507, -0.0781,  0.1486, -0.0808, -0.1729]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0167,  0.0863, -0.0232,  0.0925,  0.1978, -0.0268], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 0., 1., 1., 0.], device='cuda:0'))]
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0000, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.0000],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.0000, -0.0950,  0.0000,  0.1870, -0.0000],
          [-0.1034,  0.0000,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0000,  0.0000, -0.1746, -0.0000],
          [-0.0332, -0.1939, -0.0000,  0.0954,  0.0000],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.0000, -0.0860, -0.0436, -0.0000],
          [ 0.1238,  0.0000, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.0000, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.0000],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

删除重新参数化之后

prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([-0.0167,  0.0863, -0.0232,  0.0925,  0.1978, -0.0268], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0887,  0.0490,  0.0977,  0.0000, -0.1110],
          [ 0.0609,  0.1650,  0.0142,  0.0328, -0.0000],
          [ 0.1201, -0.1949, -0.0078,  0.1947, -0.0317],
          [-0.0000, -0.0950,  0.0000,  0.1870, -0.0000],
          [-0.1034,  0.0000,  0.1004,  0.1067,  0.1966]]],


        [[[-0.1798, -0.0000,  0.0000, -0.1746, -0.0000],
          [-0.0332, -0.1939, -0.0000,  0.0954,  0.0000],
          [ 0.0709,  0.0351, -0.0204, -0.1408,  0.1956],
          [-0.1881, -0.0000, -0.0860, -0.0436, -0.0000],
          [ 0.1238,  0.0000, -0.1757,  0.1637, -0.1336]]],


        [[[ 0.0000, -0.0898,  0.0706,  0.1766,  0.0581],
          [ 0.0269, -0.0498,  0.0255,  0.0530, -0.0000],
          [-0.1545, -0.1988,  0.1100, -0.0505, -0.1935],
          [ 0.0711,  0.1246, -0.0768, -0.1010, -0.0196],
          [ 0.0227,  0.0828, -0.1782,  0.0359,  0.1211]]],


        [[[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([0., 1., 0., 1., 1., 0.], device='cuda:0'))]

剪枝模型中的多个参数#

通过指定所需的剪枝技术和参数,我们可以轻松地剪枝网络中的多个张量,也许根据它们的类型,如我们在此示例中所见。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪枝#

到目前为止,我们只关注通常所说的“局部”剪枝,即逐个剪枝模型中的张量,通过将每个条目的统计信息(权重大小、激活、梯度等)仅与该张量中的其他条目进行比较。但是,一种常见且可能更强大的技术是立即剪枝整个模型,例如删除整个模型中最低的 20% 的连接,而不是删除每个层中最低的 20% 的连接。这可能会导致每层的不同剪枝百分比。让我们看看如何使用 torch.nn.utils.prune 中的 global_unstructured 来实现这一点。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

现在我们可以检查每个剪枝参数中诱导的稀疏性,它不会在每个层中都等于 20%。但是,全局稀疏性将是(大约)20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 6.67%
Sparsity in conv2.weight: 13.29%
Sparsity in fc1.weight: 22.22%
Sparsity in fc2.weight: 12.14%
Sparsity in fc3.weight: 8.93%
Global sparsity: 20.00%

扩展 torch.nn.utils.prune 以使用自定义剪枝函数#

为了实现您自己的剪枝函数,您可以像所有其他剪枝方法一样,通过继承 nn.utils.prune 模块中的 BasePruningMethod 基类来扩展它。基类为您实现了以下方法:__call__apply_maskapplypruneremove。在一些特殊情况下,您不需要重新实现这些方法来用于您的新剪枝技术。但是,您需要实现 __init__(构造函数)和 compute_mask(根据剪枝技术的逻辑,计算给定张量的掩码的指令)。此外,您还需要指定此技术实现哪种类型的剪枝(支持的选项是 globalstructuredunstructured)。这用于确定如何在迭代应用剪枝的情况下组合掩码。换句话说,当剪枝一个预剪枝的参数时,当前的剪枝技术预计会作用于该参数的未剪枝部分。指定 PRUNING_TYPE 将使 PruningContainer(处理剪枝掩码的迭代应用)能够正确识别要剪枝的参数切片。

例如,假设您想要实现一种剪枝技术,该技术剪枝张量中的每个其他条目(或者——如果张量之前已经被剪枝——在剩余的未剪枝张量部分中)。这将是 PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不是整个单元/通道 ('structured'),或者跨越不同的参数 ('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

现在,为了将此应用于 nn.Module 中的一个参数,您还应该提供一个简单的函数来实例化该方法并应用它。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

让我们尝试一下!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

脚本总运行时间: (0 分钟 0.541 秒)