PyTorch/XLA 中的 Tracing 时间 vs. Execution 时间¶
在使用 PyTorch/XLA 时,重要的是要理解 XLA 张量上的操作通常不会像在 CPU 或 CUDA 设备上的标准 PyTorch 张量(在“急切模式”下运行)那样立即执行。PyTorch/XLA 采用“延迟执行”模型。这意味着当你使用 XLA 张量编写 PyTorch 代码时,你主要是在定义或跟踪计算图。计算图的编译及其在设备上的后续执行会被推迟到特定的触发点。
这导致需要考虑两种不同的“时间”:
Host-Side Time(主机端时间):CPU(主机)准备计算的期间。这包括:
Tracing Time(跟踪时间):PyTorch/XLA 记录操作并构建计算图的期间。
Compilation Time(编译时间):主机端 XLA 编译器将跟踪的图转换为优化的设备代码所需的时间。这在新图的第一次执行或图发生变化时最为显著。
Device Time(设备时间):这主要是 Execution Time(执行时间),即 XLA 设备(例如 TPU)运行编译后的代码的期间。
说明常见陷阱:仅测量 Tracing 时间¶
当您使用 XLA 张量(例如,TPU 上的张量)编写 PyTorch 代码时,PyTorch/XLA 不会立即在设备上执行每个操作。它会跟踪这些操作,将它们添加到内部计算图中。如果您测量仅执行 XLA 操作而没有显式等待设备的指令的代码的持续时间,那么您主要测量的是此跟踪时间加上 Python 开销。
考虑以下概念性代码:
# Assume 'a' and 'b' are XLA tensors
start_time = time.perf_counter()
# This operation is recorded in PyTorch/XLA's graph
result = torch.matmul(a, b)
# ❌❌❌ !!! INCORRECT PROFILING: compilation and execution are deferred !!! ❌❌❌
end_time = time.perf_counter()
elapsed_time = end_time - start_time
这里的 elapsed_time
主要反映了 PyTorch/XLA 跟踪矩阵乘法运算所需的时间。XLA 设备上的实际矩阵乘法及其编译尚未开始。
测量端到端性能¶
要正确分析代码在 XLA 设备上的性能,您必须确保您的计时包含主机端编译和设备执行。这包括:
确保跟踪的计算图已被编译(如果是第一次看到该图或图已更改),并发送到设备执行。
确保 Python 脚本在获取最终时间戳之前等待 XLA 设备完成所有分配的计算。
以下概念性代码使用 torch_xla.sync(wait=True)
举例说明了这一点:
# Assume 'a' and 'b' are XLA tensors
# -- Warm-up Iteration begin ---
# The first execution of a new graph will include compilation time, as
# PyTorch/XLA translates the graph into optimized device code. To isolate the
# steady-state device execution time for consistent benchmarking, we perform a
# "warm-up" run.
_ = torch.matmul(a, b) # The result isn't needed, just triggering the op
torch_xla.sync(wait=True)
# -- Warm-up Iteration end ---
# ✅✅✅ CORRECT PROFILING
# Measure the steady-state execution time, which should exclude
# most of the initial compilation overhead.
start_time = time.perf_counter()
result = torch.matmul(a, b)
# Explicitly wait for the XLA device to finish.
torch_xla.sync(wait=True)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
触发执行并确保完成¶
有几种机制可以触发图执行和/或确保完成:
torch_xla.sync(wait=True)
:这是最直接的基准测试方法。它确保所有待处理的 XLA 操作都已启动,并且至关重要的是,它会阻塞 Python 脚本直到设备完成。Data Access/Transfer
(数据访问/传输):像tensor.cpu()
、tensor.item()
或打印 XLA 张量这样的操作需要实际数据。为了提供数据,PyTorch/XLA 必须执行生成该张量的图并等待其完成。torch_xla.core.xla_model.optimizer_step(optimizer)
:减少梯度,应用优化器更新,并有条件地通过其 barrier 参数(默认为 False,因为数据加载器通常会处理同步)触发torch_xla.sync
。torch_xla.core.xla_model.unlazy(tensors)
:阻塞直到指定的张量被具体化。
案例研究:使用 torch_xla.sync
正确分析循环¶
一个常见场景涉及循环,例如在模型训练中,其中使用了 torch_xla.sync
。考虑此结构:
def run_model():
#... XLA tensor operations...
pass
start_loop_time = time.perf_counter()
for step in range(num_steps):
run_model() # Operations are traced
torch_xla.sync() # Graph for this step is submitted for execution
# ❌❌❌ !!! INCORRECT PROFILING APPROACH FOR TOTAL TIME !!! ❌❌❌
end_loop_time = time.perf_counter()
elapsed_loop_time = end_loop_time - start_loop_time
在这种情况下,elapsed_loop_time
主要测量累积的主机端时间。这包括:
每次迭代中
run_model()
所花费的时间(大部分是跟踪)。每次迭代中
torch_xla.sync
所花费的时间,用于触发主机端编译(如果图是新的或已更改),并将该步骤的图分派到 XLA 设备执行。
至关重要的是,torch_xla.sync()
提交的图是异步运行的:Python 循环在设备仍在执行当前或之前步骤的执行时继续跟踪下一步。因此,如果设备工作落后于 Python 循环,elapsed_loop_time
不能保证包含所有 num_steps
的完整设备执行时间。
为了测量总循环时间(包括所有设备执行),必须在循环之后、获取最终时间戳之前添加 torch_xla.sync(wait=True)
。
start_loop_time = time.perf_counter()
for step in range(num_steps):
run_model_step()
torch_xla.sync()
# ✅✅✅ CORRECT PROFILING: Wait for ALL steps to complete on the device.
torch_xla.sync(wait=True)
end_loop_time = time.perf_counter()
elapsed_loop_time = end_loop_time - start_loop_time