LnStructured#
- class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)[source]#
根据 L
n
范数修剪张量中的整个(当前未修剪的)通道。- 参数
- classmethod apply(module, name, amount, n, dim, importance_scores=None)[source]#
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- 参数
module (nn.Module) – module containing the tensor to prune
name (str) – 在
module
中执行剪枝操作的参数名称。amount (int 或 float) – 要剪枝的参数数量。如果是
float
,则应介于 0.0 和 1.0 之间,表示要剪枝的参数的比例。如果是int
,则表示要剪枝的参数的绝对数量。n (int, float, inf, -inf, 'fro', 'nuc') – 请参阅
torch.norm()
中参数p
的有效条目文档。dim (int) – 定义要修剪通道的维度的索引。
importance_scores (torch.Tensor) – 用于计算修剪掩码的重要性分数张量(形状与模块参数相同)。此张量中的值表示要修剪的参数中相应元素的重要性。如果未指定或为 None,则将使用模块参数本身。
- apply_mask(module)[source]#
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- 参数
module (nn.Module) – module containing the tensor to prune
- 返回
pruned version of the input tensor
- 返回类型
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source]#
计算并返回输入张量
t
的掩码。从基础
default_mask
(如果张量尚未被修剪,则应为全 1 的掩码)开始,生成一个掩码,通过将具有最低 Ln
-范数的通道沿指定维度置零,来应用于default_mask
之上。- 参数
t (torch.Tensor) – 表示要修剪的参数的张量
default_mask (torch.Tensor) – 来自先前修剪迭代的基础掩码,在应用新掩码后需要保留。与
t
的维度相同。
- 返回
应用于
t
的掩码,维度与t
相同- 返回类型
mask (torch.Tensor)
- 引发
IndexError – 如果
self.dim >= len(t.shape)
- prune(t, default_mask=None, importance_scores=None)[source]#
Compute and returns a pruned version of input tensor
t
.根据
compute_mask()
中指定的修剪规则进行操作。- 参数
t (torch.Tensor) – 要剪枝的张量(维度与
default_mask
相同)。importance_scores (torch.Tensor) – 重要性分数张量(与
t
形状相同),用于计算剪枝t
的掩码。此张量中的值指示正在剪枝的t
中相应元素的 গুরুত্ব。如果未指定或为 None,则将使用张量t
本身。default_mask (torch.Tensor, optional) – 前一个剪枝迭代的掩码(如果有)。在确定剪枝应作用于张量的哪个部分时需要考虑。如果为 None,则默认为一个全为 1 的掩码。
- 返回
张量
t
的修剪版本。