WandaSparsifier
- class torchao.sparsity.WandaSparsifier(sparsity_level: float = 0.5, semi_structured_block_size: Optional[int] = None)
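Wanda sparsifier
Wanda (Pruning by Weights and activations), proposed in https://arxiv.org/abs/2306.11695, is an activation-aware pruning method. The sparsifier removes weights based on the product of the input activation norm and the weight magnitude.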
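The scoring rule can be illustrated with a minimal plain-Python sketch. It assumes a linear-layer weight of shape (out_features, in_features), scores each weight as |W_ij| * ||X_j||_2 (the L2 norm of input channel j over a batch of calibration inputs), and zeroes the lowest-scoring fraction. The helpers `activation_norms` and `wanda_mask` are hypothetical and are not torchao APIs. This sketch ranks scores across the whole matrix; the paper compares weights within each output row, so the comparison group here is an assumption.

```python
import math
from typing import List

Matrix = List[List[float]]


def activation_norms(inputs: Matrix) -> List[float]:
    """L2 norm of each input channel (column) over a batch of input rows."""
    return [math.sqrt(sum(row[j] ** 2 for row in inputs)) for j in range(len(inputs[0]))]


def wanda_mask(weight: Matrix, act_norms: List[float], sparsity_level: float) -> List[List[int]]:
    """0/1 mask that zeroes the lowest-scoring fraction, score = |W_ij| * ||X_j||."""
    # (score, row, col) tuples sorted by ascending score; ties break by position
    scores = sorted(
        (abs(w) * act_norms[j], i, j)
        for i, row in enumerate(weight)
        for j, w in enumerate(row)
    )
    num_pruned = int(len(scores) * sparsity_level)
    mask = [[1] * len(row) for row in weight]
    for _, i, j in scores[:num_pruned]:
        mask[i][j] = 0
    return mask


weight = [[1.0, -2.0], [0.5, 4.0]]  # shape (out_features, in_features)
inputs = [[3.0, 0.0], [4.0, 1.0]]   # two calibration rows, in_features = 2
norms = activation_norms(inputs)    # [5.0, 1.0]
print(wanda_mask(weight, norms, 0.5))  # [[1, 0], [0, 1]]
```

Magnitude-only pruning at 50% would instead zero the weights 0.5 and 1.0; here the large activation norm of input channel 0 keeps the 1.0 weight and removes -2.0.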
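This sparsifier is controlled by the constructor arguments shown in the signature; in particular, sparsity_level defines the fraction of sparse blocks that are zeroed out.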
- Parameters:
sparsity_level – the target sparsity level;
model – the model to be sparsified (passed to prepare());
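The signature also accepts semi_structured_block_size. The sketch below assumes that setting it to m means "within every group of m consecutive scores (row-major order), zero the m // 2 lowest", that is, 2:4 sparsity for m = 4, with sparsity_level no longer deciding the ratio. That interpretation is an assumption, and `semi_structured_mask` is a hypothetical helper rather than part of torchao.

```python
from typing import List


def semi_structured_mask(scores: List[float], block_size: int) -> List[int]:
    """Zero the block_size // 2 lowest scores in each consecutive block (assumed N:M rule)."""
    if len(scores) % block_size != 0:
        raise ValueError("number of scores must be a multiple of block_size")
    mask = [1] * len(scores)
    for start in range(0, len(scores), block_size):
        block = range(start, start + block_size)
        # indices of this block ordered by ascending score; keep the lowest half
        lowest = sorted(block, key=lambda k: scores[k])[: block_size // 2]
        for k in lowest:
            mask[k] = 0
    return mask


# Flattened Wanda scores of a weight with 8 entries, checked in blocks of 4.
scores = [5.0, 2.0, 2.5, 4.0, 0.1, 0.3, 9.0, 0.2]
print(semi_structured_mask(scores, 4))  # [1, 0, 0, 1, 0, 1, 1, 0]
```

Every block of four ends up with exactly two zeros, which is the fixed pattern that semi-structured sparse kernels rely on; an unstructured mask at the same 50% level could instead cluster its zeros in a few blocks.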
- prepare(model: Module, config: List[Dict]) → None
Prepares the model by adding parametrizations.
Note
The model is modified in place. If you need to preserve the original model, use copy.deepcopy.
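To illustrate what adding a parametrization and modifying in place mean, here is a plain-Python sketch that loosely mirrors how PyTorch parametrizations keep the original tensor and compute the exposed weight from it. `ToyLayer`, `toy_prepare` and `effective_weight` are hypothetical stand-ins, not torchao or PyTorch APIs. Because `toy_prepare` mutates the layer it receives, the unprepared state survives only in a copy made beforehand with `copy.deepcopy`.

```python
import copy


class ToyLayer:
    """Hypothetical layer holding a 2-D weight as nested lists."""

    def __init__(self, weight):
        self.weight = weight


def toy_prepare(layer: ToyLayer) -> None:
    """Modify `layer` in place: move the weight to `original` and add an all-ones mask."""
    layer.original = layer.weight
    layer.mask = [[1] * len(row) for row in layer.original]
    del layer.weight  # the effective weight is now derived from original and mask


def effective_weight(layer: ToyLayer):
    """What a parametrized forward pass would see: masked-out entries become 0.0."""
    return [
        [w if m else 0.0 for w, m in zip(w_row, m_row)]
        for w_row, m_row in zip(layer.original, layer.mask)
    ]


layer = ToyLayer([[1.0, -2.0], [0.5, 4.0]])
backup = copy.deepcopy(layer)    # keep an untouched copy before preparing
toy_prepare(layer)               # `layer` itself is changed
layer.mask[0][1] = 0             # a sparsifier step would update masks like this
print(effective_weight(layer))   # [[1.0, 0.0], [0.5, 4.0]]
print(backup.weight)             # [[1.0, -2.0], [0.5, 4.0]]
print(hasattr(layer, "weight"))  # False
```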
- squash_mask(params_to_keep: Optional[Tuple[str, ...]] = None, params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, *args, **kwargs)
Squashes the sparse masks into the corresponding tensors.
If either params_to_keep or params_to_keep_per_layer is set, a sparse_params dict is attached to the module.
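- Parameters:
params_to_keep – tuple of keys to save in the module, or a dict representing the modules and the keys that will have sparse parameters saved
params_to_keep_per_layer – dict specifying the parameters to save for specific layers. The keys should be module FQNs (fully qualified names), and the values should be tuples of strings naming the variables to save in sparse_params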
Examples
>>> # xdoctest: +SKIP("locals are undefined")
>>> # Don't save any sparse params
>>> sparsifier.squash_mask()
>>> hasattr(model.submodule1, 'sparse_params')
False
>>> # Keep sparse params per layer
>>> sparsifier.squash_mask(
...     params_to_keep_per_layer={
...         'submodule1.linear1': ('foo', 'bar'),
...         'submodule2.linear42': ('baz',)
...     })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'baz': 0.1}
>>> # Keep sparse params for all layers
>>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24}
>>> # Keep some sparse params for all layers, and specific ones for
>>> # some other layers
>>> sparsifier.squash_mask(
...     params_to_keep=('foo', 'bar'),
...     params_to_keep_per_layer={
...         'submodule2.linear42': ('baz',)
...     })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24, 'baz': 0.1}
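The examples above are consistent with a simple merge rule: each layer keeps the keys in params_to_keep plus any keys listed for its FQN in params_to_keep_per_layer, and sparse_params is attached only when at least one key applies. The plain-Python sketch below models that rule together with folding the mask into the weight. `ToyLayer` and `toy_squash_mask` are hypothetical, the config values (foo, bar, baz) are taken from the examples, and this is not torchao's implementation.

```python
from typing import Dict, Optional, Tuple


class ToyLayer:
    """Hypothetical prepared layer: original weights, a 0/1 mask, and a config dict."""

    def __init__(self, original, mask, config):
        self.original = original
        self.mask = mask
        self.config = config


def toy_squash_mask(
    layers: Dict[str, ToyLayer],
    params_to_keep: Optional[Tuple[str, ...]] = None,
    params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> None:
    for fqn, layer in layers.items():
        # Fold the mask into a plain weight and drop the parametrization state.
        layer.weight = [
            [w if m else 0.0 for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(layer.original, layer.mask)
        ]
        del layer.original, layer.mask

        # Global keys first, then the per-layer keys registered for this FQN.
        keys = tuple(params_to_keep or ())
        if params_to_keep_per_layer is not None:
            keys += tuple(params_to_keep_per_layer.get(fqn, ()))
        sparse_params = {k: layer.config[k] for k in keys}
        if sparse_params:  # attach only when something was requested
            layer.sparse_params = sparse_params


layers = {
    "submodule1.linear1": ToyLayer([[1.0, -2.0]], [[1, 0]], {"foo": 42, "bar": 24}),
    "submodule2.linear42": ToyLayer([[3.0, 4.0]], [[1, 1]], {"foo": 42, "bar": 24, "baz": 0.1}),
}
toy_squash_mask(
    layers,
    params_to_keep=("foo", "bar"),
    params_to_keep_per_layer={"submodule2.linear42": ("baz",)},
)
print(layers["submodule1.linear1"].weight)          # [[1.0, 0.0]]
print(layers["submodule2.linear42"].sparse_params)  # {'foo': 42, 'bar': 24, 'baz': 0.1}
```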