MarlinSparseLayout¶
- class torchao.dtypes.MarlinSparseLayout[source]¶
MarlinSparseLayout 是一个布局类,用于处理专为 Marlin 稀疏内核设计的稀疏张量格式。此布局用于优化具有 2:4 稀疏模式的仿射量化张量的存储和计算。
该布局确保张量数据已预处理并以与 Marlin 稀疏内核操作兼容的格式存储。它提供了预处理输入张量和管理量化张量布局的方法。
- pre_process(input: Tensor) Tensor [source]¶
- 预处理输入张量,使其符合 Marlin 稀疏内核的正确格式。
1. 输入张量被转置,因为线性层将权重保留在转置格式中
2. 张量被注入 2:4 稀疏性
3. 再次转置,因为量化过程将计算 dim=-1 的尺度
- 参数:
input (torch.Tensor) – 需要预处理的输入张量
- 返回:
预处理后的张量
- 返回类型: