Benchmark Utils - torch.utils.benchmark#
创建日期:2020 年 11 月 02 日 | 最后更新日期:2025 年 06 月 12 日
- class torch.utils.benchmark.Timer(stmt='pass', setup='pass', global_setup='', timer=<built-in function perf_counter>, globals=None, label=None, sub_label=None, description=None, env=None, num_threads=1, language=Language.PYTHON)[source]#
用于测量 PyTorch 语句执行时间的辅助类。
有关使用此类功能的完整教程,请参阅:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
PyTorch Timer 基于 timeit.Timer(实际上内部也使用 timeit.Timer),但有几个关键区别
- 运行时感知
Timer 将执行预热(这对于 PyTorch 的某些元素是惰性初始化的,这一点很重要),设置线程池大小以便比较是“苹果对苹果”,并在必要时同步异步加速器函数。
- 专注于副本
在测量代码,特别是复杂内核/模型时,运行-运行变化是一个重要的混淆因素。预计所有测量都应包含副本以量化噪声并允许计算中位数,中位数比平均值更稳健。为此,此类偏离了 timeit API,在概念上合并了 timeit.Timer.repeat 和 timeit.Timer.autorange。(确切的算法在方法文档字符串中讨论。)在不希望使用自适应策略的情况下,复制了 timeit 方法。
- 可选元数据
在定义 Timer 时,可以选择指定 label、sub_label、description 和 env。(稍后定义)这些字段包含在结果对象的表示中,并且 Compare 类用于按组显示结果以进行比较。
- 指令计数
除了挂钟时间之外,Timer 还可以运行 Callgrind 下的语句并报告执行的指令数。
直接对应于 timeit.Timer 构造函数参数
stmt、setup、timer、globals
PyTorch Timer 特定构造函数参数
label、sub_label、description、env、num_threads
- 参数
stmt (str) – 要在循环中运行和计时的一段代码。
setup (str) – 可选的设置代码。用于定义 stmt 中使用的变量。
global_setup (str) – (仅限 C++) 位于文件顶层的代码,用于例如 #include 语句。
timer (Callable[[], float]) – 返回当前时间的调用对象。如果 PyTorch 构建时没有加速器,或者不存在加速器,则此参数默认为 timeit.default_timer;否则,它将在测量时间之前同步加速器。
globals (Optional[dict[str, Any]]) – 在执行 stmt 时定义全局变量的字典。这是为 stmt 所需的变量提供的另一种方法。
label (Optional[str]) – 总结 stmt 的字符串。例如,如果 stmt 是 “torch.nn.functional.relu(torch.add(x, 1, out=out))”,则可以将 label 设置为 “ReLU(x + 1)” 以提高可读性。
提供补充信息,以区分具有相同 stmt 或 label 的测量。例如,在上面的示例中,sub_label 可能是 “float” 或 “int”,这样就可以轻松区分:“ReLU(x + 1): (float)”
“ReLU(x + 1): (int)” 在打印 Measurements 或使用 Compare 进行汇总时。
用于区分具有相同 label 和 sub_label 的测量的字符串。 description 的主要用途是向 Compare 表明数据的列。例如,您可以根据输入大小设置它,以创建如下表所示的表格
| n=1 | n=4 | ... ------------- ... ReLU(x + 1): (float) | ... | ... | ... ReLU(x + 1): (int) | ... | ... | ...
使用 Compare。在打印 Measurement 时也会包含它。
env (Optional[str]) – 此标签表示在不同环境中运行的、否则相同的任务,因此不等价,例如在对内核更改进行 A/B 测试时。Compare 将具有不同 env 指定的 Measurement 视为不同的,以便合并重复运行。
num_threads (int) – 执行 stmt 时 PyTorch 线程池的大小。单线程性能很重要,因为它既是关键的推理工作负载,也是内在算法效率的良好指标,因此默认设置为一。这与默认的 PyTorch 线程池大小(尝试利用所有核心)相反。
- adaptive_autorange(threshold=0.1, *, min_run_time=0.01, max_run_time=10.0, callback=None)[source]#
类似于 blocked_autorange,但还检查测量值的变异性,并重复直到 iqr/median 小于 threshold 或达到 max_run_time。
总的来说,adaptive_autorange 执行以下伪代码
`setup` times = [] while times.sum < max_run_time start = timer() for _ in range(block_size): `stmt` times.append(timer() - start) enough_data = len(times)>3 and times.sum > min_run_time small_iqr=times.iqr/times.mean<threshold if enough_data and small_iqr: break
- 参数
- 返回
一个 Measurement 对象,包含测量到的运行时间和重复次数,可用于计算统计量(平均值、中位数等)。
- 返回类型
- blocked_autorange(callback=None, min_run_time=0.2)[source]#
测量许多副本,同时将计时器开销降至最低。
总的来说,blocked_autorange 执行以下伪代码
`setup` total_time = 0 while total_time < min_run_time start = timer() for _ in range(block_size): `stmt` total_time += (timer() - start)
请注意内循环中的 block_size 变量。块大小的选择对于测量质量很重要,必须在两个相互竞争的目标之间进行权衡
较小的块大小会导致更多的副本,通常具有更好的统计数据。
较大的块大小能更好地分摊 timer 调用的成本,并导致更少的偏差测量。这一点很重要,因为加速器同步时间并非微不足道(数量级为微秒级别),否则会使测量产生偏差。
blocked_autorange 通过运行预热期来设置 block_size,增加 block_size 直到计时器开销小于总计算时间的 0.1%。然后将此值用于主测量循环。
- 返回
一个 Measurement 对象,包含测量到的运行时间和重复次数,可用于计算统计量(平均值、中位数等)。
- 返回类型
- collect_callgrind(number: int, *, repeats: None, collect_baseline: bool, retain_out_file: bool) CallgrindStats [source]#
- collect_callgrind(number: int, *, repeats: int, collect_baseline: bool, retain_out_file: bool) tuple[torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats, ...]
使用 Callgrind 收集指令计数。
与挂钟时间不同,指令计数是确定性的(除了程序本身的非确定性和 Python 解释器产生的小量抖动)。这使得它们非常适合详细的性能分析。此方法在一个单独的进程中运行 stmt,以便 Valgrind 可以对程序进行仪表化。由于仪表化,性能会严重下降,但由于少量迭代通常足以获得良好的测量结果,这得到了缓解。
要使用此方法,必须安装 valgrind、callgrind_control 和 callgrind_annotate。
由于调用者(此进程)和 stmt 执行之间存在进程边界,因此 globals 不能包含任意内存中数据结构。(与计时方法不同)相反,globals 仅限于内置函数、nn.Modules 和 TorchScripted 函数/模块,以减少序列化和后续反序列化的意外因素。GlobalsBridge 类对此主题提供了更多详细信息。请特别注意 nn.Modules:它们依赖于 pickle,您可能需要将 import 添加到 setup 中才能使它们正确传输。
默认情况下,将收集并缓存一个空语句的配置文件,以指示驱动 stmt 的 Python 循环产生了多少指令。
- 返回
一个 CallgrindStats 对象,提供指令计数以及一些基本分析和操作结果的工具。
- timeit(number=1000000)[source]#
镜像 timeit.Timer.timeit() 的语义。
执行主语句(stmt)number 次。https://docs.pythonlang.cn/3/library/timeit.html#timeit.Timer.timeit
- 返回类型
- class torch.utils.benchmark.Measurement(number_per_run, raw_times, task_spec, metadata=None)[source]#
Timer 测量结果。
此类存储给定语句的一个或多个测量值。它是可序列化的,并为下游消费者提供了几个方便的方法(包括详细的 __repr__)。
- class torch.utils.benchmark.CallgrindStats(task_spec, number_per_run, built_with_debug_symbols, baseline_inclusive_stats, baseline_exclusive_stats, stmt_inclusive_stats, stmt_exclusive_stats, stmt_callgrind_out)[source]#
Timer 收集的 Callgrind 结果的顶级容器。
通常通过调用 CallgrindStats.stats(…) 来进行操作,这会得到 FunctionCounts 类。还提供了几个便捷方法;其中最重要的是 CallgrindStats.as_standardized()。
- as_standardized()[source]#
剥离函数字符串中的库名称和一些前缀。
在比较两组不同的指令计数时,一个障碍可能是路径前缀。Callgrind 在报告函数时包含完整的文件路径(它应该如此)。然而,这在 diff 配置文件时可能会导致问题。如果两份配置文件中的关键组件(如 Python 或 PyTorch)在不同的位置构建,可能会导致类似以下内容的情况:
23234231 /tmp/first_build_dir/thing.c:foo(...) 9823794 /tmp/first_build_dir/thing.c:bar(...) ... 53453 .../aten/src/Aten/...:function_that_actually_changed(...) ... -9823794 /tmp/second_build_dir/thing.c:bar(...) -23234231 /tmp/second_build_dir/thing.c:foo(...)
剥离前缀可以通过规范化字符串来缓解此问题,并在 diff 时更好地消除等效调用站点。
- 返回类型
- delta(other, inclusive=False)[source]#
diff 两个计数集。
收集指令计数的一个常见原因是确定特定更改对执行某些工作单元所需的指令数的影响。如果更改增加了该数字,下一个合乎逻辑的问题是“为什么”。这通常涉及查看代码的哪个部分指令数增加了。此函数可以自动化此过程,以便可以轻松地在包含和排除的基础上 diff 计数。
- 返回类型
- class torch.utils.benchmark.FunctionCounts(_data, inclusive, truncate_rows=True, _linewidth=None)[source]#
用于操作 Callgrind 结果的容器。
- 它支持
加法和减法以组合或 diff 结果。
类似元组的索引。
denoise 函数,它会剥离已知的非确定性且非常嘈杂的 CPython 调用。
两个高阶方法(filter 和 transform)用于自定义操作。
- denoise()[source]#
移除已知嘈杂的指令。
CPython 解释器中的一些指令非常嘈杂。这些指令涉及 Unicode 到字典的查找,Python 使用它们来映射变量名。FunctionCounts 通常是内容无关的容器,但是,为了获得可靠的结果,这一点足够重要,值得破例。
- 返回类型
- class torch.utils.benchmark.Compare(results)[source]#
用于在格式化表中显示多个测量结果的帮助类。
表格式基于
torch.utils.benchmark.Timer
中提供的信息字段(description、label、sub_label、num_threads 等)。表可以直接使用
print()
打印,或转换为 str。有关使用此类功能的完整教程,请参阅:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
- 参数
results (list[torch.utils.benchmark.utils.common.Measurement]) – 要显示的 Measurement 列表。