TorchDynamo 集成¶
TorchDynamo 是一个 Python 级别的 JIT 编译器,旨在提高未经修改的 PyTorch 程序的性能。它为编译器后端提供了一个干净的 API 来进行挂载,其最大的特点是在 Python 字节码执行前动态修改它。在 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了实验性的后端,支持推理和训练。
XLA Bridge 的工作方式是,当 Dynamo 识别出模型模式时,它会提供一个 TorchFX 图,然后 PyTorch/XLA 使用现有的 Lazy Tensor 技术来编译 FX 图并返回编译后的函数。
集成¶
目前,PyTorch/XLA 和 Dynamo 的支持方式是将 backend='openxla'
参数添加到 torch.compile
中。例如:
import torch
import torch_xla.core.xla_model as xm
def add(a, b):
a_xla = a.to('xla')
b_xla = b.to('xla')
return a_xla + b_xla
compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
推理¶
这是一个使用 torch.compile
运行 resnet18 的小型代码示例。
import torch
import torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = torch_xla.device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='openxla')
for data, _ in loader:
with torch.no_grad():
output = dynamo_resnet18(data)
使用 torch.compile
时,您会发现 PyTorch/XLA 在初始化时只跟踪 ResNet-18 模型一次,并在每次调用 dynamo_resnet18
时执行编译后的二进制文件,而不是每次都跟踪模型。以下是使用 Cloud TPU v4-8 上的 torch bench 进行推理速度分析,以比较 Dynamo 和 Lazy 的性能。
模型 | 加速比 |
---|---|
resnet18 | 2.59 |
resnet50 | 2.64 |
resnext50_32x4d | 1.91 |
alexnet | 1.28 |
mobilenet_v2 | 18.62 |
mnasnet1_0 | 2.68 |
vgg16 | 1.33 |
BERT_pytorch | 7.49 |
squeezenet1_1 | 2.29 |
timm_vision_transformer | 3.52 |
几何平均值 | 3.04 |
训练¶
PyTorch/XLA 也支持 Dynamo 进行训练,但仍处于实验阶段,我们正在与 PyTorch Compiler 团队合作迭代实现。以下是使用 torch.compile
训练 ResNet-18 的示例。
import torch
import torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
optimizer.step()
return pred
def train_model_main(loader):
device = torch_xla.device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='openxla')
for data, target in loader:
xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)
与使用 Lazy Tensor 相比,我们期望每个训练步骤提取和执行 3 个图,而不是 1 个。以下是使用 Cloud TPU v4-8 上的 torch bench 进行训练速度分析,以比较 Dynamo 和 Lazy 的性能。
模型 | 加速比 |
---|---|
resnet50 | 1.33 |
resnet18 | 1.33 |
BERT_pytorch | 3.07 |
resnext50_32x4d | 1.43 |
alexnet | 1.12 |
mobilenet_v2 | 1.4 |
mnasnet1_0 | 1.19 |
vgg16 | 0.81 |
timm_vision_transformer | 1.87 |
squeezenet1_1 | 1.41 |
几何平均值 | 1.41 |
注意: 我们为每个模型运行一次前向和后向传播,然后收集端到端时间。在实际应用中,我们会在每个训练作业中运行多个步骤,这可以轻松地掩盖执行过程中的跟踪开销(因为它是异步的)。在这种情况下,Lazy Tensor 会获得更好的性能。
功能差距¶
有一个我们想指出的差距,它阻碍了我们在更大规模的模型上使用 TorchDynamo。
TorchDynamo 会将前向和后向传播分别跟踪为独立的图。对于 PyTorch/XLA 来说,让 XLA 编译器将整个训练步骤视为一个图以实现最佳速度优化非常重要。此外,每次启动设备执行都会有固定的开销,这使得每个训练步骤执行多个图变得不太理想。
与 Lazy Tensor 相比,这种差距在实际训练用例中效率较低,尤其是在训练过程中,跟踪开销可以与执行过程重叠。
总结¶
TorchDynamo 为编译器后端提供了一种有前景的方式,可以隐藏复杂性,方便用户检索以图格式表示的模型代码。与 PyTorch/XLA 传统的 Lazy Tensor 图提取方式相比,TorchDynamo 可以跳过每个迭代的图跟踪,从而提供更好的推理响应时间。
大多数 PyTorch/XLA 支持的模型,在运行推理时都看到了使用新的 dynamo-xla 桥接带来的显著加速。我们的社区正在努力扩展支持的模型集。关于上述训练功能差距,PyTorch/XLA 社区非常期待在我们未来的开发工作中改进这些差距。团队将继续在 TorchDynamo 上投入。