评价此页

torch.nn.utils.prune.ln_structured#

torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)[source]#

通过移除沿指定维度具有最低 Ln-范数的通道来修剪张量。

通过移除沿指定dim的最低 Ln-范数的(当前未修剪的)通道的指定amount,来修剪module中名为name的参数。通过以下方式修改module(并返回修改后的module):

  1. 添加一个名为name+'_mask'的命名缓冲区,该缓冲区对应于应用于参数name的二值掩码。

  2. 用其修剪后的版本替换参数name,同时将原始(未修剪)参数存储在名为name+'_orig'的新参数中。

参数
  • module (nn.Module) – 包含要修剪张量的模块

  • name (str) – module 中将要执行修剪的参数名。

  • amount (intfloat) – 要修剪的参数数量。如果是float,则应介于 0.0 和 1.0 之间,表示要修剪的参数的比例。如果是int,则表示要修剪的参数的绝对数量。

  • n (int, float, inf, -inf, 'fro', 'nuc') – 请参阅torch.norm() 中有效条目的文档。

  • dim (int) – 定义要修剪通道的维度索引。

  • importance_scores (torch.Tensor) – 重要性得分张量(形状与要修剪的模块参数相同),用于计算修剪的掩码。此张量中的值表示被修剪参数中相应元素的 গুরুত্ব。如果未指定或为 None,则将使用模块参数。

返回

模块的修改(即剪枝)后的版本

返回类型

module (nn.Module)

示例

>>> from torch.nn.utils import prune
>>> m = prune.ln_structured(
...     nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf")
... )