PyTorch/XLA 编译 API 及其与 Eager 模式的交互。¶
概述¶
PyTorch/XLA 将 PyTorch 与 XLA 编译器集成,以优化跨各种硬件加速器的深度学习工作负载。目前,PyTorch/XLA 默认使用 LazyTensor 跟踪模式,在这种模式下,操作会被记录到计算图中,以便进行延迟编译和执行(通过 torch_xla.sync()
触发),如下面的代码所示。
import torch
import torch_xla
import torchvision
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)
# model execution
torch_xla.sync()
虽然这种方法可以实现性能优化,但它带来了显著的用户体验挑战。
LazyTensor 模式的挑战¶
歧义:开发人员难以区分跟踪和执行阶段,这使得开发和调试变得复杂。
重新编译开销:每当捕获图的任何部分发生更改时,
torch_xla.sync()
都会重新编译整个图。非核心操作(例如数据预处理)的更改因此会触发昂贵的重新编译。调试困难:由于图构建过程的不透明性,识别重新编译的原因很困难。
Eager 模式和 torch_xla.compile
¶
为了解决这些问题,PyTorch/XLA 引入了一种实验性的 Eager 模式(通过 torch_xla.experimental.eager_mode(True)
启用)和 torch_xla.compile
API。这种转变使 PyTorch/XLA 更接近原生 PyTorch,优先考虑开发人员体验,同时保持性能。Eager 模式可能会成为未来版本的默认设置。
Eager 模式:立即执行操作,增强灵活性和调试能力,但会以性能为代价。
torch_xla.compile:一个装饰器或包装器,它在 Eager 上下文中显式标记代码(例如模型或函数)以进行 XLA 编译,从而提供清晰的界限和即时反馈。
请注意,torch_xla.compile
即使在 Eager 模式之外也具有独立的用途,它提供了诸如通过将数据加载操作捕获到单独的图中来防止它们泄漏到训练循环图中的好处,并且在指定 full_graph=True
时捕获意外的图中断。
torch_xla.compile
的工作原理¶
让我们看一下 torch_xla.compile
的基本用法。
import torch
import torch_xla
import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)
# Compilation and execution happens right away.
res = compiled_model(input)
其中 torch_xla.compile
的实现可以概括如下:
禁用 Eager 模式:暂时切换到跟踪以构建计算图。
跟踪操作:记录操作以供 XLA 优化。
编译和执行:通过内部
torch_xla.sync()
调用触发编译和执行。重新启用 Eager 模式:编译后恢复 Eager 执行。
这种“Eager-到-Lazy-到-Eager”的过渡抽象了同步复杂性,平衡了灵活性和性能。
torch_xla.compile
与 torch.compile
¶
PyTorch 生态系统提供了多种编译 API,了解它们的不同作用,尤其是在 PyTorch/XLA 中,对于实现最佳性能和开发至关重要。
torch_xla.compile
针对 PyTorch/XLA 训练工作流进行了优化。它旨在与 XLA 后端高效配合以进行迭代训练,是编译训练循环的推荐 API,因为它具有已观察到的性能优势。最佳实践是将完整的训练步骤,例如前向传播、损失计算、后向传播和优化器步骤,封装在step_fn
中,然后编译此函数。
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
optimizer.zero_grad()
logits = model(data)
loss = loss_fn(logits, target)
loss.backward()
optimizer.step()
return loss
step_fn = torch_xla.compile(step_fn)
torch.compile
是 PyTorch 的通用编译 API,旨在跨各种后端加速 PyTorch 模型。对于 PyTorch/XLA,它使用openxla
后端。我们推荐torch.compile
用于 PyTorch/XLA 推理,因为它降低了跟踪开销,从而实现了更有效的静态推理图。要将其与 XLA 一起使用,只需指定backend="openxla"
。
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
长远目标是让 torch.compile
成为 XLA 上训练和推理的单一编译 API。
性能基准¶
为了量化 torch_xla.compile
和 Eager 模式的性能影响,在特定条件下进行了基准测试。基准测试使用了类似于 Llama2 的 2 层仅解码器模型,并使用假数据进行了训练。训练过程在单个 v4-8 TPU 芯片上进行了 300 步。观察到的性能(以 tokens/s 为单位)清楚地说明了不同执行模式的影响。
模式 |
token/s |
---|---|
跟踪模式(基线) |
147 |
Eager 模式 |
65 |
Eager + |
147 |
Eager 模式配合 torch_xla.compile
在 147 tokens/s 的速度下达到了传统 LazyTensor 跟踪模式的性能,在不损失性能的情况下提供了更好的用户体验。
纯 Eager 模式的性能取决于模型;对于仅解码器模型,它达到了完全编译模型的 ~45% 的性能。然而,对于 ResNet50,纯 Eager 模式的性能要差得多(约为编译模式的 1%)。有关更多信息,请参阅 train_decoder_only_base.py 和 eager 示例。这种可变的开销意味着纯 Eager 模式不适用于主要的训练或推理循环。它的用途在于非核心任务,如数据预处理、随机数生成、自定义实用程序或调试,在这些任务中,即时执行比吞吐量更受重视。