torch.nn.utils.prune.custom_from_mask
- torch.nn.utils.prune.custom_from_mask(module, name, mask)
Prune the tensor corresponding to the parameter called name in module by applying the pre-computed mask in mask.
Modifies module in place (and also returns the modified module) by:
- adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method;
- replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
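The reparametrization described above can be inspected directly. The following is a minimal sketch, assuming a recent PyTorch release; the variable names are illustrative only. It checks that the same module object is returned, that the original values move to a parameter named bias_orig, that the mask is registered as a buffer named bias_mask, and that bias becomes the product of the two.

```python
import torch
from torch import nn
from torch.nn.utils import prune

linear = nn.Linear(5, 3)
mask = torch.tensor([0, 1, 0])

# Prune the "bias" parameter in place; the function returns the same module object.
pruned = prune.custom_from_mask(linear, name="bias", mask=mask)
assert pruned is linear

# The original (unpruned) values now live in a parameter named "bias_orig" ...
param_names = sorted(n for n, _ in pruned.named_parameters())
# ... and the binary mask is stored as a buffer named "bias_mask".
buffer_names = sorted(n for n, _ in pruned.named_buffers())

print(param_names)   # ['bias_orig', 'weight']
print(buffer_names)  # ['bias_mask']

# "bias" is no longer a registered parameter: it is a plain tensor attribute
# equal to bias_orig * bias_mask.
assert torch.equal(pruned.bias, pruned.bias_orig * pruned.bias_mask)
```

Because the unpruned values are kept in bias_orig, the mask can later be changed or removed without losing the original parameter.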
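- Parameters
module (nn.Module) – module containing the tensor to prune
name (str) – parameter name within module on which pruning will act
mask (Tensor) – binary mask to be applied to the parameter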
- Returns
modified (i.e. pruned) version of the input module
- Return type
module (nn.Module)
Examples
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
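A mask can also target a weight matrix. The sketch below uses a hypothetical column mask, chosen only for illustration, that keeps the first two input features of an nn.Linear(5, 3) layer (its weight has shape (out_features, in_features) = (3, 5)). It assumes, as in current PyTorch, that pruning recomputes weight from weight_orig * weight_mask before each forward call, so the last three input features should have no effect on the output.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(5, 3)

# Hypothetical mask with the same shape as layer.weight (3, 5):
# keep only the first two input columns, prune the rest.
weight_mask = torch.zeros(3, 5)
weight_mask[:, :2] = 1

prune.custom_from_mask(layer, name="weight", mask=weight_mask)

# Entries where the mask is 0 are zero in the effective weight.
print(layer.weight[:, 2:].abs().sum().item())  # 0.0

# The forward pass uses the pruned weight, so changing the
# last three input features does not change the output.
x = torch.randn(4, 5)
x_changed = x.clone()
x_changed[:, 2:] = 100.0
print(torch.allclose(layer(x), layer(x_changed)))  # True
```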