MarlinSparseLayout¶
- class torchao.dtypes.MarlinSparseLayout[source]¶
MarlinSparseLayout 是一个用于处理稀疏张量格式的布局类,它专门为 Marlin 稀疏内核设计。该布局用于优化具有 2:4 稀疏模式的仿射量化张量的存储和计算。
该布局确保张量数据以与 Marlin 稀疏内核操作兼容的格式进行预处理和存储。它提供了预处理输入张量和管理量化张量布局的方法。
- pre_process(input: Tensor) Tensor [source]¶
- 预处理输入张量,使其符合 Marlin 稀疏内核的正确格式。
1. 输入张量被转置,因为线性层以转置格式保留权重。
2. 张量注入 2:4 稀疏性。
3. 再次转置,因为量化过程将计算 dim=-1 的尺度。
- 参数:
input (torch.Tensor) – 要预处理的输入张量。
- 返回:
预处理后的张量。
- 返回类型: