评价此页

使用 CPU 上的 Max-Autotune 编译以获得更好性能#

作者Jiong Gong, Leslie Fang, Chunyuan Wu

在本教程中,您将学习如何通过利用 Inductor CPU 后端的 max-autotune 模式来提升 PyTorch 模型在 CPU 上的性能。探索激活过程,理解与传统方法的区别,并将 max-autotune 集成到您的代码中以提高计算效率。深入了解使用高级 GEMM 模板以实现更快的处理和卓越的运行时性能。

先决条件:#

简介#

torch.compile ( RFC 链接) 的 Inductor CPU 后端的 max-autotune 模式在编译时分析多个操作实现,并选择性能最佳的实现,从而牺牲更长的编译时间来换取改进的运行时性能。此增强功能对于 GEMM 相关操作尤其有益。在 Inductor CPU 后端,我们引入了一种基于 C++ 模板的 GEMM 实现,作为依赖 oneDNN 和 MKL 库的基于 ATen 的方法的替代方案。这类似于 CUDA 上的 max-autotune 模式,其中会考虑 ATen、Triton 和 CUTLASS 的实现。

我们已经涵盖了大多数流行的数据类型,包括 FP32、BF16、FP16 和 INT8,并为 x86 CPU 提供了外插融合。

尽管开发仍在进行中,但根据三个基准测试套件和 LLM 推理的衡量,我们已经看到了比纯 ATen 基于 GEMM 的性能有显著提升。

激活 max-autotune 模式#

要激活 PyTorch 中的 max-autotune 模式,请在使用 torch.compile 编译模型时将 mode 参数设置为 max-autotune。如果您希望绕过调优过程并始终使用 C++ 模板实现,可以通过环境变量进行配置:export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP

示例#

下面的代码是一个在简单神经网络上使用 max-autotune 模式的示例,该网络包含一个线性层后跟一个 ReLU 激活。

在基于 C++ 模板的 GEMM 实现中,我们将预先打包权重以获得良好的缓存利用。在推理(CPU AI 工作负载的主要场景)的情况下,模型权重是恒定的,我们在编译期间预先打包它们,以便数据访问在缓存块内是连续的。因此,我们仅支持带有 torch.no_grad 或推理模式的冻结模型。您需要设置环境变量 export TORCHINDUCTOR_FREEZING=1 并确保编译和推理步骤都在 torch.no_grad 上下文中执行。

import torch
from torch._inductor import config
config.trace.log_autotuning_results = True # enable the log of autotuning results

class M(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias,
        **kwargs,
    ):
        super().__init__()
        self.linear = torch.nn.Linear(
            in_features,
            out_features,
            bias,
            **kwargs,
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        return x

amp_enabled = True
batch_size = 64
in_features = 16
out_features = 32
bias = True

x = torch.randn(batch_size, in_features)
model = M(in_features, out_features, bias)

with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
    compiled = torch.compile(model, mode="max-autotune") # turn on "max-autotune" mode
    y = compiled(x)

运行上述代码片段时,您将看到自动调优结果(性能数字仅用于演示目的)。在此示例中,C++ 模板的性能优于 ATen 内核,因此将被选中。

AUTOTUNE linear_unary(64x16, 32x16, 32)
cpp_packed_gemm_0 0.2142 ms 100.0%
_linear_pointwise 0.2441 ms 87.7%

我们可以通过设置 export TORCH_LOGS="+output_code" 来检查生成的输出代码。当选择 C++ 模板时,我们不再在生成的代码中看到 torch.ops.mkldnn._linear_pointwise.default(用于 bfloat16)或 torch.ops.mkl._mkl_linear.default(用于 float32),而是会找到基于 CPP GEMM 模板 cpp_fused__to_copy_relu_1(为简洁起见,仅展示部分代码)的内核,并将偏置和 relu 外插融合到 C++ GEMM 模板内核中。

生成的代码因 CPU 架构而异,并且是实现特定的,可能会发生更改。

cpp_fused__to_copy_relu_1 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*'], '''

...

template <bool accum>
inline void kernel_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    ...
}

...

template <bool accum>
inline void kernel_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    ...
}

extern "C"
void kernel(const bfloat16* X, const bfloat16* W, const bfloat16* inp, bfloat16* Y)
{
    constexpr int64_t num_threads = 40;
    constexpr int64_t N = 32;
    constexpr int64_t K = 16;
    constexpr int64_t M = static_cast<int64_t>(64L);
    ...
    #pragma omp parallel num_threads(40)
    {
        const int tid = omp_get_thread_num();
        ...
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            ...
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                ...
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    ...
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            kernel_micro_gemm<static_cast<bool>(false)>(
                                ...
                            );

                        } else {
                            kernel_micro_gemm<static_cast<bool>(true)>(
                                ...
                            );

                        }
                    }
                }
                {
                    {
                        // Epilogue fusion here for bias and relu
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(inp + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
                                auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<float>(tmp0);
                                auto tmp3 = tmp1 + tmp2;
                                auto tmp4 = at::vec::convert<bfloat16>(tmp3);
                                auto tmp5 = static_cast<float>(0.0);
                                auto tmp6 = at::vec::Vectorized<float>(tmp5);
                                auto tmp7 = at::vec::maximum(tmp3, tmp6);
                                auto tmp8 = at::vec::convert<bfloat16>(tmp7);
                                tmp8.store(Y + static_cast<int64_t>(n_start + x1 + (32L*m_start) + (32L*x0)), static_cast<int64_t>(16));
                            }

                            ...

                        }
                    }

                }
            }
        }
        ...
    }
}
''')

结论#

在本教程中,我们介绍了 CPU 上的 max-autotune 支持和 GEMM 模板。我们解释了激活此功能的 API,并演示了 GEMM 模板生成的代码。

此功能处于原型阶段。如果您有任何功能请求或遇到任何问题,请在 GitHub issues 上提交 bug 报告。