Eager 模式 + Compile API¶
本文档将介绍如何使用 PyTorch/XLA 的新实验性 eager
模式和 compile
API。目标是使 PyTorch/XLA 的体验更符合原生 PyTorch,并简化开发过程。
目前 PyTorch/XLA 默认在 LazyTensor 跟踪模式下运行。在以下代码中
import torch
import torch_xla
import torchvision
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)
# model execution, same as `xm.mark_step`
torch_xla.sync()
实际的模型编译和设备执行发生在调用 torch_xla.sync
时。这种方法有几个缺点。
用户经常对框架何时进行跟踪和何时执行感到困惑。
非核心模型代码(例如数据预处理)经常会生成一些小的待执行操作,这些操作会泄露到主图(step 函数)中并导致重新编译。整个图的重新编译通常成本很高。
很难调试何时/为何会发生重新编译。
为了缓解上述问题,我们希望引入新的 eager 和 compile 用户体验。
基本用法¶
import torch
import torch_xla
import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)
# Compilation and execution happens right away.
res = compiled_model(input)
请注意
目前用户必须通过
torch_xla.experimental.eager_mode(True)
手动启用 eager 模式。希望被编译的代码区域应被
torch_xla.compile
包裹。
torch_xla.compile
的实现实际上相当直接,它在进入目标函数时禁用 eager 模式并开始跟踪。它会在目标函数返回时调用 `torch_xla.sync()` 并重新启用 eager 模式。与现有的 `mark_step/sync` 方法相比,您可以使用 `eager` + `compile` API 获得相同的性能。
推理¶
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
对于推理,建议使用 `torch.compile` 而不是 `torch_xla.compile`,以减少跟踪开销。
训练¶
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
optimizer.zero_grad()
logits = model(data)
loss = loss_fn(logits, target)
loss.backward()
optimizer.step()
return loss
step_fn = torch_xla.compile(step_fn)
在训练中,我们要求用户将 `step_fn` 分离出来,因为通常最好将模型的 forward、backward 和优化器一起编译。长期目标也是为训练使用 `torch.compile`,但目前我们建议用户出于性能原因使用 `torch_xla.compile`。
基准测试¶
我在一个 v4-8 的芯片上,使用伪数据训练了一个 2 层仅解码器模型(基本上就是一个 llama2),进行了 300 步。以下是我观察到的数字。
模式 token/s
跟踪模式(基线) 147 Eager 模式 65 Eager + torch_xla compile 147
: Eager 模式基准测试
对于仅解码器模型,Eager 模式的性能达到了完全编译模型的约 45%。有关更多信息,请参阅 train_decoder_only_base.py 和 eager 示例。请注意,Eager 模式的性能高度依赖于模型。当我尝试运行 resnet50 时,Eager 模式的性能约为编译模式的 1%。我们不期望用户使用 Eager 模式来执行主要的训练循环。Eager 模式用于处理训练/推理逻辑的非核心部分(数据预处理、随机数生成等)或用于调试。