评价此页

L1Unstructured#

class torch.nn.utils.prune.L1Unstructured(amount)[source]#

Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

参数

amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.

classmethod apply(module, name, amount, importance_scores=None)[source]#

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

参数
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – parameter name within module on which pruning will act.

  • amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.

  • importance_scores (torch.Tensor) – tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place.

apply_mask(module)[source]#

Simply handles the multiplication between the parameter being pruned and the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

参数

module (nn.Module) – module containing the tensor to prune

返回

pruned version of the input tensor

返回类型

pruned_tensor (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]#

Compute and returns a pruned version of input tensor t.

According to the pruning rule specified in compute_mask().

参数
  • t (torch.Tensor) – 要修剪的张量(与 default_mask 的维度相同)。

  • importance_scores (torch.Tensor) – 重要性分数张量(与 t 的形状相同),用于计算修剪 t 的掩码。此张量中的值指示正在被修剪的 t 中相应元素的重要性。如果未指定或为 None,则将使用 t 张量代替。

  • default_mask (torch.Tensor, optional) – 来自上一个修剪迭代的掩码(如果有)。在确定修剪应作用于张量的哪个部分时会考虑此掩码。如果为 None,则默认为一个全 1 的掩码。

返回

张量 t 的修剪版本。

remove(module)[源代码]#

从模块中移除修剪重参数化。

已修剪的名为 name 的参数将永久保留修剪状态,名为 name+'_orig' 的参数将从参数列表中移除。同样,名为 name+'_mask' 的缓冲区将从缓冲区中移除。

注意

修剪本身**不会**被撤销或恢复!