(beta) Building a Simple CPU Performance Profiler with FX#

Created On: Mar 04, 2021 | Last Updated: Jul 14, 2025 | Last Verified: Not Verified

Author: James Reed

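In this tutorial, we are going to use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect it and gather statistics about the structure and execution of the code.

  2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.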

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, bias=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

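Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model take the longest?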

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

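A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.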
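A minimal sketch of this manual approach is shown here on a toy two-stage pipeline rather than on the model itself. The functions `preprocess`, `compute`, and `pipeline` are hypothetical stand-ins, and `time.perf_counter` is used as the clock, which is an assumption of this sketch.

```python
import time


def preprocess(values):
    # Hypothetical first stage: scale every input value
    return [v * 0.5 for v in values]


def compute(values):
    # Hypothetical second stage: an artificially heavier computation
    return [sum(v * i for i in range(200)) for v in values]


def pipeline(values, timings):
    # Timestamps are inserted by hand around each region of interest
    t0 = time.perf_counter()
    scaled = preprocess(values)
    t1 = time.perf_counter()
    result = compute(scaled)
    t2 = time.perf_counter()
    # Differences between adjacent timestamps give per-region durations
    timings["preprocess"] = t1 - t0
    timings["compute"] = t2 - t1
    return result


timings = {}
result = pipeline(list(range(1000)), timings)
for region, seconds in timings.items():
    print(f"{region}: {seconds:.6f} s")
```

Note that every region to be measured requires an edit to the function body, which is exactly the cost that grows with the size of the program.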

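That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy over the model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.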

First, let's get some imports out of the way (we will be using all of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

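Note

tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.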
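As a rough illustration of what tabulate is used for, the sketch below renders a small table and falls back to tab-separated text when the package is not installed. The helper name `render_table` and the row values are hypothetical and are not taken from any measurement.

```python
from typing import Any, List

try:
    import tabulate  # external package, not bundled with PyTorch

    def render_table(rows: List[List[Any]], headers: List[str]) -> str:
        # Delegate formatting to tabulate when it is available
        return tabulate.tabulate(rows, headers=headers)

except ImportError:

    def render_table(rows: List[List[Any]], headers: List[str]) -> str:
        # Plain fallback: one tab-separated line per row
        lines = ["\t".join(headers)]
        lines += ["\t".join(str(cell) for cell in row) for row in rows]
        return "\n".join(lines)


# Illustrative placeholder rows (made-up numbers)
rows = [["call_module", "conv1", 0.002], ["call_module", "relu", 0.001]]
table = render_table(rows, ["Op type", "Op", "Average runtime (s)"])
print(table)
```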

Capturing the Model with Symbolic Tracing#

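Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.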

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

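This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.ac.cn/docs/stable/fx.html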
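To make the Node and edge structure concrete, the following sketch walks the graph of a much smaller module and reads each Node's `op`, `target`, and `args`. `TinyNet` is a hypothetical module introduced only to keep the graph short.

```python
import torch
import torch.fx


class TinyNet(torch.nn.Module):
    # Hypothetical two-op module, used only for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


traced = torch.fx.symbolic_trace(TinyNet())

# Collect (op, target, args) for every Node, in execution order.
# The args of a Node refer to other Nodes, which is how edges are encoded.
node_info = [(node.op, node.target, node.args) for node in traced.graph.nodes]
for op, target, args in node_info:
    print(op, target, args)
```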

Creating a Profiling Interpreter#

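Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when you call a GraphModule, an alternative way to run a GraphModule is by executing each Node in the Graph one by one. That is the functionality that Interpreter provides: it interprets the graph node by node.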
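The two execution paths can be compared directly. The sketch below, which assumes a small hypothetical module named `TinyNet`, runs the same GraphModule once through its generated code and once through a plain `Interpreter`, then checks that the results agree.

```python
import torch
import torch.fx
from torch.fx import Interpreter


class TinyNet(torch.nn.Module):
    # Hypothetical module, used only for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = torch.fx.symbolic_trace(TinyNet())
x = torch.randn(3, 4)

# Path 1: call the GraphModule, which runs its generated Python code
compiled_out = gm(x)
# Path 2: let Interpreter execute the Graph one Node at a time
interpreted_out = Interpreter(gm).run(x)

print(torch.allclose(compiled_out, interpreted_out))
```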

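By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each part of the model took during those runs.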
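As a warm-up for the override mechanism, here is a minimal sketch of a subclass that only counts how many Nodes of each kind are executed. The names `OpCountingInterpreter` and `TinyNet` are hypothetical, and the counting behavior is a stand-in for the timing behavior this tutorial is after.

```python
import collections
from typing import Any

import torch
import torch.fx
from torch.fx import Interpreter


class OpCountingInterpreter(Interpreter):
    # Hypothetical subclass: counts executed Nodes by their op kind
    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.op_counts = collections.Counter()

    def run_node(self, n: torch.fx.Node) -> Any:
        # Custom behavior goes around the call to the base implementation
        self.op_counts[n.op] += 1
        return super().run_node(n)


class TinyNet(torch.nn.Module):
    # Hypothetical module, used only for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


interp = OpCountingInterpreter(torch.fx.symbolic_trace(TinyNet()))
# State kept on the interpreter accumulates across repeated runs
interp.run(torch.randn(3, 4))
interp.run(torch.randn(3, 4))
print(dict(interp.op_counts))
```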

Let's define our ProfilingInterpreter:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

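Note

We use Python's time.time function to pull wall clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.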
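For somewhat steadier numbers, one common option is to switch to `time.perf_counter`, discard a few warm-up calls, and report the median of several repeats. The sketch below shows that pattern on a hypothetical `workload` function; the helper name `measure` and the warm-up and repeat counts are arbitrary choices, and this is not what the tutorial's profiler does.

```python
import statistics
import time


def measure(fn, warmup=2, repeats=10):
    # Warm-up calls are executed but not timed
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    # The median is less sensitive to occasional slow outliers than the mean
    return statistics.median(samples), samples


def workload():
    # Hypothetical stand-in for a model invocation
    return sum(i * i for i in range(10_000))


median_sec, samples = measure(workload)
print(f"median over {len(samples)} runs: {median_sec:.6f} s")
```

PyTorch also ships a `torch.utils.benchmark` module intended for this kind of measurement, which may be a better fit for serious benchmarking.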

Investigating the Performance of ResNet18#

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.00528169             9.19708
call_module    conv1                            0.00460577             8.02009
call_module    layer4_1_conv1                   0.00317168             5.5229
call_module    layer4_0_conv2                   0.00315022             5.48553
call_module    layer1_0_conv1                   0.00299764             5.21983
call_module    layer4_1_conv2                   0.00292659             5.09611
call_module    layer1_1_conv1                   0.00267553             4.65894
call_module    layer1_0_conv2                   0.00251722             4.38328
call_module    layer2_1_conv2                   0.00232363             4.04617
call_module    layer3_1_conv1                   0.00226665             3.94694
call_module    layer2_1_conv1                   0.00224924             3.91664
call_module    layer2_0_conv2                   0.0022471              3.9129
call_module    layer3_1_conv2                   0.00222349             3.8718
call_module    layer1_1_conv2                   0.00218534             3.80537
call_module    layer3_0_conv2                   0.0021193              3.69037
call_module    layer4_0_conv1                   0.00193357             3.36696
call_module    layer2_0_conv1                   0.00166345             2.89658
call_module    layer3_0_conv1                   0.0012784              2.2261
call_module    layer2_0_downsample_0            0.00105214             1.83211
call_module    bn1                              0.000759363            1.32229
call_function  add                              0.000511169            0.890107
call_module    layer3_0_downsample_0            0.000475645            0.828248
call_module    layer4_0_downsample_0            0.000448704            0.781334
call_function  add_1                            0.000395775            0.689168
call_module    layer1_1_bn1                     0.000281334            0.489891
call_module    fc                               0.000258923            0.450866
call_function  add_3                            0.000252008            0.438826
call_module    relu                             0.000247955            0.431768
call_module    layer1_0_bn1                     0.000231504            0.403122
call_module    layer1_0_bn2                     0.000210762            0.367003
call_module    layer2_0_bn1                     0.000196218            0.341678
call_module    layer2_0_bn2                     0.000184536            0.321335
call_module    layer2_1_bn2                     0.000177383            0.30888
call_module    layer2_0_downsample_1            0.000157833            0.274837
call_module    layer1_1_bn2                     0.00015521             0.27027
call_module    layer3_1_bn1                     0.000134468            0.234151
call_module    avgpool                          0.000122309            0.212978
call_module    layer4_0_bn2                     0.000120163            0.209241
call_module    layer4_1_bn2                     0.000118017            0.205505
call_module    layer4_1_bn1                     0.000116348            0.202599
call_module    layer3_1_bn2                     0.00011301             0.196787
call_module    layer4_0_relu                    0.000105858            0.184332
call_module    layer3_0_bn2                     0.000104189            0.181426
call_function  add_2                            0.000100136            0.174368
call_module    layer1_0_relu_1                  9.67979e-05            0.168556
call_module    layer2_1_bn1                     9.60827e-05            0.16731
call_module    layer3_0_relu_1                  9.39369e-05            0.163574
call_module    layer1_0_relu                    9.34601e-05            0.162743
call_module    layer3_0_downsample_1            8.72612e-05            0.151949
call_module    layer1_1_relu_1                  8.58307e-05            0.149458
call_function  add_5                            7.79629e-05            0.135758
call_module    layer4_0_downsample_1            7.67708e-05            0.133682
call_module    layer3_0_bn1                     7.60555e-05            0.132437
call_module    layer1_1_relu                    7.48634e-05            0.130361
call_module    layer4_0_bn1                     7.43866e-05            0.12953
call_module    layer4_1_relu                    6.77109e-05            0.117906
call_function  add_7                            6.55651e-05            0.114169
call_function  add_6                            6.29425e-05            0.109603
call_module    layer2_0_relu                    4.98295e-05            0.0867688
call_module    layer2_0_relu_1                  4.98295e-05            0.0867688
call_module    layer2_1_relu_1                  4.95911e-05            0.0863536
call_module    layer2_1_relu                    4.60148e-05            0.0801262
call_module    layer4_0_relu_1                  4.57764e-05            0.079711
call_function  add_4                            4.3869e-05             0.0763898
call_module    layer4_1_relu_1                  4.22001e-05            0.0734836
call_module    layer3_1_relu                    3.98159e-05            0.069332
call_module    layer3_0_relu                    3.86238e-05            0.0672562
call_module    layer3_1_relu_1                  3.57628e-05            0.0622743
placeholder    x                                3.43323e-05            0.0597833
call_function  flatten                          2.76566e-05            0.0481588
output         output                           1.19209e-05            0.0207581

There are two things we should call out here:

Conclusion#

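As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.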

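Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (pytorch/pytorch#issues) to provide any feedback you might have.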

Total running time of the script: (0 minutes 0.308 seconds)