
torch.nn.utils.prune.identity

torch.nn.utils.prune.identity(module, name)

Apply pruning reparametrization without pruning any units.

Applies pruning reparametrization to the tensor corresponding to the parameter called name in module without actually pruning any units. Modifies module in place (and also returns the modified module) by:

  1. adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.

  2. replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
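The two steps above can be observed directly on a module. The following is a minimal sketch, assuming a small nn.Linear layer (the names layer, param_names and buffer_names are illustrative only). It checks that the original tensor is now registered as the parameter weight_orig, the mask as the buffer weight_mask, and that weight itself is no longer an nn.Parameter but a plain tensor derived from the two.

```python
import torch.nn as nn
from torch.nn.utils import prune

# Illustrative layer; any module with a parameter called "weight" would do.
layer = nn.Linear(4, 2)
prune.identity(layer, "weight")

# The original (unpruned) tensor is now stored under name + "_orig"...
param_names = sorted(n for n, _ in layer.named_parameters())
# ...and the binary mask is registered as a buffer under name + "_mask".
buffer_names = [n for n, _ in layer.named_buffers()]

print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

# "weight" is now a plain tensor attribute (weight_orig * weight_mask),
# not a registered nn.Parameter.
print(isinstance(layer.weight, nn.Parameter))  # False
```

Because weight_orig remains the registered parameter, an optimizer built from layer.parameters() after this call updates weight_orig rather than weight.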

Note

The mask is a tensor of ones.
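Since the mask contains only ones, the reparametrized tensor equals the original values and the module computes the same outputs as before. A minimal sketch checking this, assuming an illustrative nn.Linear layer and random input x:

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(3, 2)      # illustrative module
x = torch.randn(5, 3)        # illustrative input batch
before = layer(x).detach().clone()

prune.identity(layer, "weight")

# Every entry of the mask is 1, so no unit is zeroed out.
print(bool(torch.all(layer.weight_mask == 1)))         # True
# The effective weight is weight_orig * 1, i.e. identical to the original.
print(torch.equal(layer.weight, layer.weight_orig))    # True
# Consequently the forward pass is unchanged.
after = layer(x).detach()
print(torch.allclose(before, after))                   # True
```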

Parameters
  • module (nn.Module) – module containing the tensor to prune.

  • name (str) – parameter name within module on which pruning will act.

Returns

modified (i.e. pruned) version of the input module

Return type

module (nn.Module)

Example

>>> import torch.nn as nn
>>> from torch.nn.utils import prune
>>> m = prune.identity(nn.Linear(2, 3), "bias")
>>> print(m.bias_mask)
tensor([1., 1., 1.])