评价此页

使用 torch.compile 性能分析#

创建于:2023年6月6日 | 最后更新于:2025年6月10日

torch.profiler 的用途:#

torch.profiler 有助于以内核粒度了解程序的性能 - 例如,它可以显示图的断点和程序级别的资源利用率。剖析器提供的数据通常可以帮助用户了解在哪里可以进一步调查以理解模型性能。

要了解内核级别的性能,还有其他工具,例如 Nvidia Nsight compute toolAMD Omnitrace、Intel® VTune™ Profiler 或 inductor 的性能分析工具

另请参阅 通用 pytorch 性能分析指南

torch.profiler 的使用和跟踪视图基础#

示例程序:我们将使用此示例来分析 resnet18 的性能。请注意此示例程序中的以下部分

  • 包含一个预热运行以等待编译完成(这将预热 CUDA 缓存分配器等系统)

  • 使用 torch.profiler.profile() 上下文来分析我们感兴趣的部分

  • 使用 prof.export_chrome_trace("trace.json") 导出性能分析构件。


    import torch
    from torchvision.models import resnet18

    device = 'cuda'      # or 'cpu', 'xpu', etc.
    model = resnet18().to(device)

    inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]

    model_c = torch.compile(model)

    def fwd_bwd(inp):
        out = model_c(inp)
        out.sum().backward()

    # warm up
    fwd_bwd(inputs[0])

    with torch.profiler.profile() as prof:
        for i in range(1, 4):
            fwd_bwd(inputs[i])
            prof.step()

    prof.export_chrome_trace("trace.json")

查看 chrome 跟踪:在 Chrome 浏览器中,打开 chrome://tracing 并加载 json 文件。使用“w”和“s”键放大和缩小,使用“a”和“d”键左右滚动。“?” 将显示一个“帮助”屏幕,其中包含快捷方式列表。

Example of a basic chrome trace, visualized in the chrome://tracing viewer

在这里,我们观察到

  • CompiledFunction 和 CompiledFunctionBackward 事件,它们对应于 dynamo 编译的区域。

  • 顶部的 CPU 事件,底部的 GPU 事件。

CPU 和加速器事件之间的流

加速器上的每个内核都在由 CPU 上运行的代码启动后发生。性能分析器可以绘制加速器和 CPU 事件之间的连接(即“流”)以显示哪个 CPU 事件启动了加速器内核。这特别有用,因为,除了少数例外,加速器内核是异步启动的。

要查看流连接,请单击 GPU 内核并单击“ac2g”

Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location.

或者,在顶部使用“Flow events”下拉菜单打开 *所有* 流。

解决 CUDA Graph 性能分析问题#

启用 CUDA 图时,某些 CUDA 配置(驱动程序版本低于 525.85.12 或 CUDA < 12)可能会在性能分析工具和 CUDA 图之间遇到问题。要解决这些问题,请在程序顶部添加一个空的性能分析上下文


    import torch

    torch.profiler._utils._init_for_cuda_graphs()

    # ... rest of program

了解编译时间#

要了解为什么编译需要很长时间,您可以分析 torch.compile 编译程序的第一次调用。请记住,编译的性能分析跟踪可能比典型性能分析更容易失真,因为编译工作负载可能与典型的 PyTorch 工作负载非常不同。在某些情况下,跟踪文件也可能非常大。跟踪文件大于 1GB 可能难以使用 chrome 跟踪工具打开。

注意:使用 :code:torch._dynamo.utils.compile_times() 也可以以非图形格式获取大致相同的信息。此实用程序不会显示编译步骤何时发生,但它会显示每一步花费的时间 - 并且时间不会受到任何性能分析开销的影响。

示例如下


    import torch
    from torchvision.models import resnet18

    # user can switch between cuda and xpu
    device = 'cuda'
    model = resnet18().to(device)
    inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]

    model_c = torch.compile(model)

    def fwd_bwd(inp):
        out = model_c(inp)
        out.sum().backward()

    def warmup_compile():
        def fn(x):
            return x.sin().relu()

        x = torch.rand((2, 2), device=device, requires_grad=True)
        fn_c = torch.compile(fn)
        out = fn_c(x)
        out.sum().backward()

    with torch.profiler.profile() as prof:
        with torch.profiler.record_function("warmup compile"):
            warmup_compile()

        with torch.profiler.record_function("resnet18 compile"):
            fwd_bwd(inputs[0])

    prof.export_chrome_trace("trace_compile.json")
A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps

请注意几点

  • 第一次调用应在性能分析 *期间* 发生,以便捕获编译

  • 添加预热编译以初始化任何需要延迟初始化的系统。

查找图中断:“Torch-Compiled Region”和“CompiledFunction”#

尽管有用于识别图中断的日志记录工具,但性能分析器提供了一种快速的可视化方法来识别 :ref: 中断 <torch.compiler_graph_breaks>。有两个性能分析事件需要查找:**Torch-Compiled Region** 和 **CompiledFunction**。

**Torch-Compiled Region** - 该区域是在 PyTorch 2.2 中引入的 - 是一个覆盖整个编译区域的性能分析事件。图中断几乎总是看起来相同:嵌套的“Torch-Compiled Region”事件。

如果您使用分别应用于每个函数的两个独立函数,则通常会期望看到两个相邻(即*不*堆叠/嵌套)的 Torch-Compiled 区域。同时,如果您遇到图中断(或 disable() 的/跳过的区域),则期望看到嵌套的“Torch-Compiled Region”事件。

**CompiledFunction** - 在 PyTorch 2.0 中引入 - 是当需要任何输入的梯度时出现的性能分析事件。每个图中断都会中断一个 CompiledFunction 块,将其分成两半。CompiledFunction 事件仅在涉及 Autograd 时出现,即图的输入张量中有 requires_grad=True。

当跟踪中出现 CompiledFunction 时,它通常与反向传播中的 CompiledFunctionBackward 事件配对。如果调用了反向函数,则“fwd-bwd 链接”应出现在连接两者的跟踪中。

如果您的用例包含不需要 grad 且不包含“Torch-Compiled Region”事件的图,那么要确定 torch.compile 是否正在正确应用可能会更加困难。一个线索可能是 Inductor 生成的 Triton 内核的存在。

示例如下


    import torch
    import torch._dynamo
    # user can switch between cuda and xpu
    device = 'cuda'

    class ModelWithBreaks(torch.nn.Module):
        def __init__(self):
            super().__init__()
            def create_sequential():
                return torch.nn.Sequential(
                    torch.nn.Linear(128, 128),
                    torch.nn.ReLU(),
                    torch.nn.Linear(128, 128),
                    torch.nn.ReLU(),
                )
            self.mod1 = create_sequential()
            self.mod2 = create_sequential()
            self.mod3 = create_sequential()
            self.mod4 = create_sequential()

        def forward(self, inp):
            mod1 = self.mod1(inp)
            torch._dynamo.graph_break()
            mod2 = self.mod2(mod1)
            torch._dynamo.graph_break()
            mod3 = self.mod3(mod2)
            torch._dynamo.graph_break()
            mod4 = self.mod4(mod3)
            return mod4

    model = ModelWithBreaks().to(device)
    inputs = [torch.randn((128, 128), device=device) for _ in range(10)]

    model_c = torch.compile(model)

    def fwd_bwd(inp):
        out = model_c(inp)
        out.sum().backward()

    # warm up
    fwd_bwd(inputs[0])

    with torch.profiler.profile() as prof:
        for i in range(1, 4):
            fwd_bwd(inputs[i])
            prof.step()

    prof.export_chrome_trace("trace_break.json")
Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks.

运算符内核#

当启动运算符时,我们期望看到几个事件

  1. CPU 端事件

  2. 内核启动(如果处理 GPU 内核)

  3. GPU 端事件

Visualization in the chrome://trace viewer, showing the three types of events - CPU-side event, kernel launch, and GPU-side event

Inductor 生成的 Triton 内核

  1. **CPU 端事件**应显示为以“triton_”为前缀的事件。目前事件信息很少 - 内核名称和启动,但信息比典型的 aten 内核启动(包含输入形状、类型等)少。

  2. **内核启动**应显示为 cuLaunchKernel 而不是 cudaLaunchKernel(cudaLaunchKernel 是 aten 运算符的典型)。

  3. **GPU 端事件**应显示,其名称的描述性取决于 inductor 配置的 unique_kernel_names。

_images/triton_kernel_launch.png

非 Inductor 生成的 Triton 内核

  1. **CPU 端**事件可能不会出现在跟踪中;插入性能分析事件的机制目前是在 Inductor 级别实现的,因此绕过 Inductor 的 Triton 内核可能不会出现在跟踪中,除非用户手动注释了它们。

  2. **内核启动**应显示为 cuLaunchKernel 而不是 cudaLaunchKernel(cudaLaunchKernel 是 aten 运算符的典型)。

  3. **GPU 端**事件应显示,其命名方式类似于所编写的 triton 内核。

_images/noninductor_triton_kernel.png

Inductor 生成的 CPU 内核

  1. **CPU 端事件**不会出现在跟踪中;我们尚未为此添加性能分析。

  2. **内核启动**和**GPU 端事件**不存在

**非 Triton 内核**(即 aten 内核或自定义运算符)也可能偶尔出现在跟踪中。有时,Inductor 会回退到原始运算符实现,在这种情况下,您将看到对 aten 运算符的调用。

启动开销#

一个常见的 GPU 利用率问题是糟糕的 GPU 利用率。识别此问题的一种快捷方法是 GPU 上内核之间是否存在大间隙

Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches.

这通常是 CPU 开销的结果,例如,如果 CPU 之间启动内核所花费的时间大于 GPU 处理内核所花费的时间。对于小批量大小,此问题更常见。

当使用 inductor 时,启用 CUDA 图通常有助于提高启动开销存在问题的性能。