PruningContainer#
- class torch.nn.utils.prune.PruningContainer(*args)[source]#
用于迭代剪枝的剪枝方法序列的容器。
跟踪应用剪枝方法的顺序,并处理连续剪枝调用的合并。
接受 BasePruningMethod 的实例或其可迭代对象作为参数。
- add_pruning_method(method)[source]#
向容器添加一个子剪枝方法
method
。- 参数
method (BasePruningMethod 的子类) – 要添加到容器的子剪枝方法。
- classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]#
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- apply_mask(module)[source]#
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- 参数
module (nn.Module) – module containing the tensor to prune
- 返回
pruned version of the input tensor
- 返回类型
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source]#
通过计算新的部分掩码来应用最新的
method
,并将其与default_mask
组合返回。新的部分掩码应在未被
default_mask
清零的条目或通道上计算。新掩码将根据PRUNING_TYPE
(由类型处理程序处理)从张量t
的哪些部分计算,这取决于。对于“非结构化”,掩码将从非零条目的展平列表中计算;
对于“结构化”,掩码将从张量中的非零通道计算;
对于“全局”,掩码将在所有条目中计算。
- 参数
t (torch.Tensor) – 表示要剪枝的参数的张量(与
default_mask
的维度相同)。default_mask (torch.Tensor) – 前一个剪枝迭代的掩码。
- 返回
组合了
default_mask
和当前剪枝method
的新掩码的效果(与default_mask
和t
的维度相同)。- 返回类型
mask (torch.Tensor)
- prune(t, default_mask=None, importance_scores=None)[source]#
Compute and returns a pruned version of input tensor
t
.根据
compute_mask()
中指定的剪枝规则。- 参数
t (torch.Tensor) – 要剪枝的张量(维度与
default_mask
相同)。importance_scores (torch.Tensor) – 重要性分数张量(与
t
形状相同),用于计算剪枝t
的掩码。此张量中的值指示正在剪枝的t
中相应元素的 গুরুত্ব。如果未指定或为 None,则将使用张量t
本身。default_mask (torch.Tensor, optional) – 前一个剪枝迭代的掩码(如果有)。在确定剪枝应作用于张量的哪个部分时需要考虑。如果为 None,则默认为一个全为 1 的掩码。
- 返回
张量
t
的修剪版本。