Benchmark Utils - torch.utils.benchmark#
创建于: 2020年11月02日 | 最后更新于: 2025年06月12日
- class torch.utils.benchmark.Timer(stmt='pass', setup='pass', global_setup='', timer=<built-in function perf_counter>, globals=None, label=None, sub_label=None, description=None, env=None, num_threads=1, language=Language.PYTHON)[source]#
用于测量 PyTorch 语句执行时间的辅助类。
有关如何使用此类,请参阅完整教程: https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
PyTorch Timer 基于 timeit.Timer(实际上内部使用了 timeit.Timer),但有几处关键区别
- 运行时感知
Timer 将执行预热(这对于 PyTorch 的某些元素被延迟初始化很重要),设置线程池大小以使比较公平,并在必要时同步异步 CUDA 函数。
- 关注重复测量
在测量代码,特别是复杂内核/模型时,运行间的变化是一个重要的混杂因素。预期所有测量都应包含重复测量,以量化噪声并允许计算中位数,这比计算平均值更稳健。为此,此类偏离了 timeit API,将 timeit.Timer.repeat 和 timeit.Timer.autorange 在概念上合并。 (具体算法在方法文档字符串中讨论。) 在不需要自适应策略的情况下,会复制 timeit 方法。
- 可选元数据
定义 Timer 时,可以选择指定 label、sub_label、description 和 env。(稍后定义)这些字段包含在结果对象的表示中,并且 Compare 类会使用它们来对结果进行分组和显示,以便进行比较。
- 指令计数
除了壁钟时间外,Timer 还可以运行 Callgrind 下的代码并报告执行的指令数。
与 timeit.Timer 构造函数参数直接对应
stmt、setup、timer、globals
PyTorch Timer 特有的构造函数参数
label、sub_label、description、env、num_threads
- 参数
stmt (str) – 要在循环中运行和计时的代码片段。
setup (str) – 可选的设置代码。用于定义 stmt 中使用的变量。
global_setup (str) – (仅 C++ ) 放置在文件顶层的代码,用于 `#include` 语句等。
timer (Callable[[], float]) – 返回当前时间的调用函数。如果 PyTorch 构建时没有 CUDA 或没有 GPU,则默认为 timeit.default_timer;否则,它将在测量时间之前同步 CUDA。
globals (Optional[dict[str, Any]]) – 一个字典,用于在执行 stmt 时定义全局变量。这是提供 stmt 所需变量的另一种方法。
label (Optional[str]) – 总结 stmt 的字符串。例如,如果 stmt 是 “torch.nn.functional.relu(torch.add(x, 1, out=out))”,则可以设置 label 为 “ReLU(x + 1)” 以提高可读性。
提供补充信息以区分具有相同 stmt 或 label 的测量。例如,在上面的示例中,sub_label 可能是 “float” 或 “int”,以便于区分:“ReLU(x + 1): (float)”
”ReLU(x + 1): (int)” 在打印 Measurement 或使用 Compare 进行汇总时。
用于区分具有相同 label 和 sub_label 的测量的字符串。`description` 的主要用途是向 `Compare` 指示数据列。例如,可以根据输入大小设置它来创建如下格式的表格
| n=1 | n=4 | ... ------------- ... ReLU(x + 1): (float) | ... | ... | ... ReLU(x + 1): (int) | ... | ... | ...
使用 Compare。在打印 Measurement 时也包含在内。
env (Optional[str]) – 此标签表示其他相同的任务是在不同的环境中运行的,因此不相等,例如在 A/B 测试内核更改时。当合并重复运行结果时,Compare 会将具有不同 env 指定的 Measurement 视为不同的。
num_threads (int) – 执行 stmt 时 PyTorch 线程池的大小。单线程性能很重要,因为它既是关键的推理工作负载,也是内在算法效率的良好指标,因此默认设置为一。这与默认的 PyTorch 线程池大小形成对比,后者会尝试利用所有核心。
- adaptive_autorange(threshold=0.1, *, min_run_time=0.01, max_run_time=10.0, callback=None)[source]#
与 blocked_autorange 类似,但还检查测量值的变异性,并重复执行直到 iqr/median 小于 threshold 或达到 max_run_time。
高层来看,adaptive_autorange 执行以下伪代码
`setup` times = [] while times.sum < max_run_time start = timer() for _ in range(block_size): `stmt` times.append(timer() - start) enough_data = len(times)>3 and times.sum > min_run_time small_iqr=times.iqr/times.mean<threshold if enough_data and small_iqr: break
- 参数
- 返回
一个 Measurement 对象,包含测量的运行时间和重复次数,可用于计算统计数据(平均值、中位数等)。
- 返回类型
- blocked_autorange(callback=None, min_run_time=0.2)[source]#
测量多个重复项,同时将计时器开销降至最低。
高层来看,blocked_autorange 执行以下伪代码
`setup` total_time = 0 while total_time < min_run_time start = timer() for _ in range(block_size): `stmt` total_time += (timer() - start)
请注意内循环中的 `block_size` 变量。`block_size` 的选择对测量质量很重要,并且必须平衡两个相互竞争的目标:
较小的 `block_size` 会导致更多的重复测量,通常统计数据更好。
较大的 `block_size` 可以更好地分摊 `timer` 调用的成本,并产生偏差更小的测量。这一点很重要,因为 CUDA 同步时间是不可忽略的(顺序约为几个微秒到十几微秒),否则会使测量产生偏差。
blocked_autorange 通过运行预热周期来设置 `block_size`,并增加 `block_size` 直到计时器开销小于总计算量的 0.1%。然后将此值用于主测量循环。
- 返回
一个 Measurement 对象,包含测量的运行时间和重复次数,可用于计算统计数据(平均值、中位数等)。
- 返回类型
- collect_callgrind(number: int, *, repeats: None, collect_baseline: bool, retain_out_file: bool) CallgrindStats [源代码]#
- collect_callgrind(number: int, *, repeats: int, collect_baseline: bool, retain_out_file: bool) tuple[torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats, ...]
使用 Callgrind 收集指令计数。
与挂钟时间不同,指令计数是确定性的(除了程序本身的非确定性和 Python 解释器带来的一些微小抖动)。这使得它们非常适合详细的性能分析。此方法在单独的进程中运行 stmt,以便 Valgrind 可以检测程序。由于检测,性能会严重下降,但由于通常只需要少量迭代即可获得良好的测量结果,因此可以缓解这种情况。
要使用此方法,必须安装 valgrind、callgrind_control 和 callgrind_annotate。
由于调用方(此进程)和 stmt 执行之间存在进程边界,因此 globals 不能包含任意内存数据结构。(与计时方法不同)相反,globals 仅限于内置函数、nn.Modules 和 TorchScripted 函数/模块,以减少序列化和后续反序列化的意外因素。GlobalsBridge 类对此主题提供了更多详细信息。请特别注意 nn.Modules:它们依赖于 pickle,您可能需要在 setup 中添加一个导入才能正确传输它们。
默认情况下,将收集并缓存空语句的配置文件,以指示有多少指令来自驱动 stmt 的 Python 循环。
- 返回
一个 CallgrindStats 对象,它提供指令计数以及一些用于分析和操作结果的基本工具。
- timeit(number=1000000)[源代码]#
镜像 timeit.Timer.timeit() 的语义。
执行主语句(stmt)number 次。https://docs.pythonlang.cn/3/library/timeit.html#timeit.Timer.timeit
- 返回类型
- class torch.utils.benchmark.Measurement(number_per_run, raw_times, task_spec, metadata=None)[源代码]#
Timer 测量结果。
此类存储给定语句的一个或多个测量结果。它是可序列化的,并为下游使用者提供了几种便捷方法(包括详细的 __repr__)。
- class torch.utils.benchmark.CallgrindStats(task_spec, number_per_run, built_with_debug_symbols, baseline_inclusive_stats, baseline_exclusive_stats, stmt_inclusive_stats, stmt_exclusive_stats, stmt_callgrind_out)[源代码]#
Timer 收集的 Callgrind 结果的顶级容器。
通常通过调用 CallgrindStats.stats(…) 使用 FunctionCounts 类进行操作。还提供了几个便捷方法;其中最重要的是 CallgrindStats.as_standardized()。
- as_standardized()[源代码]#
剥离函数字符串中的库名和一些前缀。
在比较两组不同的指令计数时,一个障碍可能是路径前缀。Callgrind 在报告函数时包含完整的 filepath(它应该如此)。但是,这可能导致 diff 配置文件时出现问题。如果两组配置文件中的 Python 或 PyTorch 等关键组件在不同位置构建,则可能导致类似以下情况:
23234231 /tmp/first_build_dir/thing.c:foo(...) 9823794 /tmp/first_build_dir/thing.c:bar(...) ... 53453 .../aten/src/Aten/...:function_that_actually_changed(...) ... -9823794 /tmp/second_build_dir/thing.c:bar(...) -23234231 /tmp/second_build_dir/thing.c:foo(...)
剥离前缀可以正则化字符串,从而在 diff 时实现更好的等效调用站点取消,从而缓解此问题。
- 返回类型
- delta(other, inclusive=False)[源代码]#
差异两组计数。
收集指令计数的一个常见原因是为了确定特定更改对执行某些工作单位所需的指令数的影响。如果更改增加了该数量,下一个逻辑问题是“为什么”。这通常涉及查看代码的哪个部分指令计数增加了。此函数使此过程自动化,以便可以轻松地在包含和排除的基础上差异计数。
- 返回类型
- class torch.utils.benchmark.FunctionCounts(_data, inclusive, truncate_rows=True, _linewidth=None)[源代码]#
用于操作 Callgrind 结果的容器。
- 它支持
加法和减法以组合或差异结果。
类似元组的索引。
denoise 函数,用于剥离已知非确定性且非常嘈杂的 CPython 调用。
两个高阶方法(filter 和 transform)用于自定义操作。
- denoise()[源代码]#
删除已知嘈杂的指令。
CPython 解释器中的几条指令非常嘈杂。这些指令涉及 Unicode 到字典的查找,Python 使用它们来映射变量名。FunctionCounts 通常是内容无关的容器,但为了获得可靠的结果,这一点足够重要,可以作为例外。
- 返回类型
- class torch.utils.benchmark.Compare(results)[源代码]#
用于在格式化表格中显示多个测量结果的帮助类。
表格格式基于
torch.utils.benchmark.Timer
中提供的信息字段(description、label、sub_label、num_threads 等)。可以通过使用
print()
直接打印表格,或者将其强制转换为 str。有关如何使用此类,请参阅完整教程: https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
- 参数
results (list[torch.utils.benchmark.utils.common.Measurement]) – 要显示的 Measurment 列表。