评价此页

BasePruningMethod#

class torch.nn.utils.prune.BasePruningMethod[source]#

用于创建新剪枝技术的抽象基类。

提供了一个用于自定义的骨架,需要重写诸如 compute_mask()apply() 等方法。

classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]#

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

参数
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – 在 module 中执行剪枝操作的参数名称。

  • args – 传递给 BasePruningMethod 子类的参数

  • importance_scores (torch.Tensor) – 重要性得分张量(形状与模块参数相同),用于计算剪枝的掩码。此张量中的值表示正在剪枝的参数中相应元素的重要性。如果未指定或为 None,则将使用参数本身。

  • kwargs – 传递给 BasePruningMethod 子类的关键字参数

apply_mask(module)[source]#

Simply handles the multiplication between the parameter being pruned and the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

参数

module (nn.Module) – module containing the tensor to prune

返回

pruned version of the input tensor

返回类型

pruned_tensor (torch.Tensor)

abstract compute_mask(t, default_mask)[source]#

计算并返回输入张量 t 的掩码。

从一个基础的 default_mask(如果张量尚未被剪枝,则应为全为 1 的掩码)开始,根据特定的剪枝方法规则生成一个随机掩码,以应用于 default_mask 之上。

参数
  • t (torch.Tensor) – 表示待剪枝参数重要性得分的张量

  • prune. (parameter to) –

  • default_mask (torch.Tensor) – 上一次迭代剪枝的基础掩码

  • iterations

  • is (that need to be respected after the new mask) –

  • t. (applied. Same dims as) –

返回

应用于 t 的掩码,维度与 t 相同

返回类型

mask (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]#

Compute and returns a pruned version of input tensor t.

根据 compute_mask() 中指定的剪枝规则。

参数
  • t (torch.Tensor) – 要剪枝的张量(维度与 default_mask 相同)。

  • importance_scores (torch.Tensor) – 重要性分数张量(与 t 形状相同),用于计算剪枝 t 的掩码。此张量中的值指示正在剪枝的 t 中相应元素的 গুরুত্ব。如果未指定或为 None,则将使用张量 t 本身。

  • default_mask (torch.Tensor, optional) – 前一个剪枝迭代的掩码(如果有)。在确定剪枝应作用于张量的哪个部分时需要考虑。如果为 None,则默认为一个全为 1 的掩码。

返回

张量 t 的修剪版本。

remove(module)[source]#

从模块中移除修剪重参数化。

名为 name 的已剪枝参数将永久保持剪枝状态,名为 name+'_orig' 的参数将从参数列表中移除。类似地,名为 name+'_mask' 的缓冲区也将从缓冲区中移除。

注意

修剪本身**不会**被撤销或恢复!