• 文档 >
  • PyTorch/XLA 中的跟踪时间与执行时间
快捷方式

PyTorch/XLA 中的 Tracing 时间 vs. Execution 时间

在使用 PyTorch/XLA 时,重要的是要理解 XLA 张量上的操作通常不会像在 CPU 或 CUDA 设备上的标准 PyTorch 张量(在“急切模式”下运行)那样立即执行。PyTorch/XLA 采用“延迟执行”模型。这意味着当你使用 XLA 张量编写 PyTorch 代码时,你主要是在定义或跟踪计算图。计算图的编译及其在设备上的后续执行会被推迟到特定的触发点。

这导致需要考虑两种不同的“时间”:

  1. Host-Side Time(主机端时间):CPU(主机)准备计算的期间。这包括:

    • Tracing Time(跟踪时间):PyTorch/XLA 记录操作并构建计算图的期间。

    • Compilation Time(编译时间):主机端 XLA 编译器将跟踪的图转换为优化的设备代码所需的时间。这在新图的第一次执行或图发生变化时最为显著。

  2. Device Time(设备时间):这主要是 Execution Time(执行时间),即 XLA 设备(例如 TPU)运行编译后的代码的期间。

说明常见陷阱:仅测量 Tracing 时间

当您使用 XLA 张量(例如,TPU 上的张量)编写 PyTorch 代码时,PyTorch/XLA 不会立即在设备上执行每个操作。它会跟踪这些操作,将它们添加到内部计算图中。如果您测量仅执行 XLA 操作而没有显式等待设备的指令的代码的持续时间,那么您主要测量的是此跟踪时间加上 Python 开销。

考虑以下概念性代码:

# Assume 'a' and 'b' are XLA tensors
start_time = time.perf_counter()

# This operation is recorded in PyTorch/XLA's graph
result = torch.matmul(a, b)

# ❌❌❌ !!! INCORRECT PROFILING: compilation and execution are deferred !!! ❌❌❌
end_time = time.perf_counter()
elapsed_time = end_time - start_time

这里的 elapsed_time 主要反映了 PyTorch/XLA 跟踪矩阵乘法运算所需的时间。XLA 设备上的实际矩阵乘法及其编译尚未开始。

测量端到端性能

要正确分析代码在 XLA 设备上的性能,您必须确保您的计时包含主机端编译和设备执行。这包括:

  1. 确保跟踪的计算图已被编译(如果是第一次看到该图或图已更改),并发送到设备执行。

  2. 确保 Python 脚本在获取最终时间戳之前等待 XLA 设备完成所有分配的计算。

以下概念性代码使用 torch_xla.sync(wait=True) 举例说明了这一点:

# Assume 'a' and 'b' are XLA tensors

# -- Warm-up Iteration begin ---

# The first execution of a new graph will include compilation time, as
# PyTorch/XLA translates the graph into optimized device code. To isolate the
# steady-state device execution time for consistent benchmarking, we perform a
# "warm-up" run.
_ = torch.matmul(a, b) # The result isn't needed, just triggering the op
torch_xla.sync(wait=True)

# -- Warm-up Iteration end ---

# ✅✅✅ CORRECT PROFILING
# Measure the steady-state execution time, which should exclude
# most of the initial compilation overhead.
start_time = time.perf_counter()

result = torch.matmul(a, b)

# Explicitly wait for the XLA device to finish.
torch_xla.sync(wait=True)

end_time = time.perf_counter()
elapsed_time = end_time - start_time

触发执行并确保完成

有几种机制可以触发图执行和/或确保完成:

  1. torch_xla.sync(wait=True):这是最直接的基准测试方法。它确保所有待处理的 XLA 操作都已启动,并且至关重要的是,它会阻塞 Python 脚本直到设备完成。

  2. Data Access/Transfer(数据访问/传输):像 tensor.cpu()tensor.item() 或打印 XLA 张量这样的操作需要实际数据。为了提供数据,PyTorch/XLA 必须执行生成该张量的图并等待其完成。

  3. torch_xla.core.xla_model.optimizer_step(optimizer):减少梯度,应用优化器更新,并有条件地通过其 barrier 参数(默认为 False,因为数据加载器通常会处理同步)触发 torch_xla.sync

  4. torch_xla.core.xla_model.unlazy(tensors):阻塞直到指定的张量被具体化。

案例研究:使用 torch_xla.sync 正确分析循环

一个常见场景涉及循环,例如在模型训练中,其中使用了 torch_xla.sync。考虑此结构:

def run_model():
    #... XLA tensor operations...
    pass

start_loop_time = time.perf_counter()
for step in range(num_steps):
  run_model()  # Operations are traced
  torch_xla.sync()    # Graph for this step is submitted for execution

# ❌❌❌ !!! INCORRECT PROFILING APPROACH FOR TOTAL TIME !!! ❌❌❌
end_loop_time = time.perf_counter()
elapsed_loop_time = end_loop_time - start_loop_time

在这种情况下,elapsed_loop_time 主要测量累积的主机端时间。这包括:

  1. 每次迭代中 run_model() 所花费的时间(大部分是跟踪)。

  2. 每次迭代中 torch_xla.sync 所花费的时间,用于触发主机端编译(如果图是新的或已更改),并将该步骤的图分派到 XLA 设备执行。

至关重要的是,torch_xla.sync() 提交的图是异步运行的:Python 循环在设备仍在执行当前或之前步骤的执行时继续跟踪下一步。因此,如果设备工作落后于 Python 循环,elapsed_loop_time 不能保证包含所有 num_steps 的完整设备执行时间。

为了测量总循环时间(包括所有设备执行),必须在循环之后、获取最终时间戳之前添加 torch_xla.sync(wait=True)

start_loop_time = time.perf_counter()
for step in range(num_steps):
  run_model_step()
  torch_xla.sync()

# ✅✅✅ CORRECT PROFILING: Wait for ALL steps to complete on the device.
torch_xla.sync(wait=True)

end_loop_time = time.perf_counter()
elapsed_loop_time = end_loop_time - start_loop_time

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源