• 文档 >
  • PyTorch XLA 中的 TorchDynamo 集成
快捷方式

TorchDynamo 在 PyTorch XLA 中的集成

TorchDynamo 是一个 Python 级别的 JIT 编译器,旨在加快未经修改的 PyTorch 程序的速度。它为编译器后端提供了干净的 API 来挂钩,其最大的特点是在 Python 字节码执行之前动态修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了用于推理和训练的实验性后端。

XLA 桥接的工作方式是,当 Dynamo 识别出模型模式时,它会提供一个 TorchFX 图,而 PyTorch/XLA 将使用现有的 Lazy Tensor 技术来编译 FX 图并返回编译后的函数。

集成

目前通过向 torch.compile 添加 backend='openxla' 参数来支持 PyTorch/XLA 和 Dynamo。例如

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))

推理

这是一个使用 torch.compile 运行 resnet18 的小型代码示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

使用 torch.compile 时,您会发现 PyTorch/XLA 在 init 期间仅跟踪 resnet18 模型一次,并在每次调用 dynamo_resnet18 时执行编译后的二进制文件,而不是每次都跟踪模型。这是在 Cloud TPU v4-8 上使用 torch bench 进行推理速度分析,以比较 Dynamo 和 Lazy。

模型 加速
resnet18 2.59
resnet50 2.64
resnext50_32x4d 1.91
alexnet 1.28
mobilenet_v2 18.62
mnasnet1_0 2.68
vgg16 1.33
BERT_pytorch 7.49
squeezenet1_1 2.29
timm_vision_transformer 3.52
几何平均 3.04

训练

PyTorch/XLA 也支持 Dynamo 进行训练,但这还是实验性的,我们正在与 PyTorch 编译器团队合作迭代实现。这是一个使用 torch.compile 训练 resnet18 的示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)

我们期望与 Lazy Tensor 相比,每个训练步骤提取和执行 3 个图,而不是 1 个图。这是在 Cloud TPU v4-8 上使用 torch bench 进行训练速度分析,以比较 Dynamo 和 Lazy。

模型 加速
resnet50 1.33
resnet18 1.33
BERT_pytorch 3.07
resnext50_32x4d 1.43
alexnet 1.12
mobilenet_v2 1.4
mnasnet1_0 1.19
vgg16 0.81
timm_vision_transformer 1.87
squeezenet1_1 1.41
几何平均 1.41

注意: 我们对每个模型的 fwd 和 bwd 运行单个步骤,然后收集端到端时间。在实际应用中,我们会在每个训练作业中运行多个步骤,这样可以轻松地隐藏执行中的跟踪成本(因为它是异步的)。在这种情况下,Lazy Tensor 的性能会好得多。

功能差距

有一个差距是我们想指出的,它阻碍了我们在更大规模模型上使用 TorchDynamo。

TorchDynamo 会将前向和后向分别跟踪到单独的图中。对于 PyTorch/XLA 来说,让 XLA 编译器将整个步骤视为一个图以获得最佳速度优化至关重要。每次启动设备执行都有一个固定的开销,这使得每个训练步骤执行多个图不太理想。

与 Lazy Tensor 相比,这个差距使得它在实际训练用例中的效率较低,尤其是在训练中跟踪成本可以与执行成本重叠时。

要点

TorchDynamo 为编译器后端提供了一种非常有前途的方式,可以隐藏用户复杂性,并轻松以图形格式检索模型代码。与 PyTorch/XLA 传统的提取图的 Lazy Tensor 方法相比,TorchDynamo 可以跳过每次迭代的图跟踪,从而提供更好的推理响应时间。

PyTorch/XLA 支持的大多数模型在运行推理时,通过新的 dynamo-xla 桥接都实现了显著的加速。我们的社区正在努力扩展支持的模型集。关于上述训练功能差距,PyTorch/XLA 社区非常期待在即将进行的开发工作中改进训练差距。该团队将继续大力投入 TorchDynamo 并与上游合作,以完善训练解决方案。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源