PyTorch on XLA Devices¶
PyTorch 可在 XLA 设备(如 TPU)上运行,通过 torch_xla 包。本文档介绍如何在这些设备上运行您的模型。
创建 XLA Tensor¶
PyTorch/XLA 为 PyTorch 添加了一个新的 xla
设备类型。此设备类型的工作方式与其他 PyTorch 设备类型相同。例如,以下是如何创建和打印 XLA Tensor:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device='xla')
print(t.device)
print(t)
这段代码应该看起来很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并进行了一些补充。导入 torch_xla
会初始化 PyTorch/XLA,而 torch_xla.device()
会返回当前的 XLA 设备。根据您的环境,这可能是 CPU 或 TPU。
XLA Tensors 是 PyTorch Tensors¶
PyTorch 操作可以在 XLA Tensor 上执行,就像在 CPU 或 CUDA Tensor 上一样。
例如,XLA Tensor 可以相加:
t0 = torch.randn(2, 2, device='xla')
t1 = torch.randn(2, 2, device='xla')
print(t0 + t1)
或者矩阵相乘:
print(t0.mm(t1))
或者与神经网络模块一起使用:
l_in = torch.randn(10, device='xla')
linear = torch.nn.Linear(10, 20).to('xla')
l_out = linear(l_in)
print(l_out)
与其他设备类型一样,XLA Tensor 只能与同一设备上的其他 XLA Tensor 一起工作。因此,像这样的代码:
l_in = torch.randn(10, device='xla')
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
会引发错误,因为 torch.nn.Linear
模块位于 CPU 上。
在 XLA 设备上运行模型¶
构建一个新的 PyTorch 网络或转换现有网络以在 XLA 设备上运行,只需要几行 XLA 特定的代码。以下代码片段在单个设备和使用 XLA 多进程的多个设备上运行时突出了这些行。
在单个 XLA 设备上运行¶
以下代码片段显示了一个在单个 XLA 设备上训练的网络:
import torch_xla.core.xla_model as xm
device = torch_xla.device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
torch_xla.sync()
这段代码片段突出了将模型切换到 XLA 上运行的便捷性。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一 XLA 特定的代码是获取 XLA 设备和具体化 Tensor 的几行。在每次训练迭代结束时调用 torch_xla.sync()
会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA Tensor 深度解析。
使用多进程在多个 XLA 设备上运行¶
PyTorch/XLA 使通过在多个 XLA 设备上运行来加速训练变得容易。以下代码片段展示了如何实现:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
def _mp_fn(index):
device = torch_xla.device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in mp_device_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
这个多设备代码片段与之前的单设备代码片段有三个不同之处。让我们一一介绍。
torch_xla.launch()
创建运行 XLA 设备的进程。
此函数是 multithreading spawn 的包装器,允许用户也使用 torchrun 命令行运行脚本。每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将启动 4 个进程,每个进程拥有一个 TPU 设备。
请注意,如果在每个进程中打印
torch_xla.device()
,您将看到所有设备上显示xla:0
。这是因为每个进程只能看到一个设备。这并不意味着多进程功能不起作用。由于存在#devices/2
个进程,并且每个进程有 2 个线程(有关更多详细信息,请检查此 文档),因此在 TPU v2 和 TPU v3 上,唯一的执行是与 PJRT 运行时一起进行的。
MpDeviceLoader
将训练数据加载到每个设备上。
MpDeviceLoader
可以包装在 torch dataloader 上。它可以预加载数据到设备,并使数据加载与设备执行重叠,以提高性能。MpDeviceLoader
还会为您在每batches_per_execution
(默认为 1)批次生成时调用torch_xla.sync()
。
xm.optimizer_step(optimizer)
汇总设备之间的梯度并发出 XLA 设备步进计算。
它基本上是
all_reduce_gradients
+optimizer.step()
+torch_xla.sync()
,并返回已减少的损失。
模型定义、优化器定义和训练循环保持不变。
注意: 重要的是要注意,在使用多进程时,用户只能从
torch_xla.launch()
的目标函数(或任何以torch_xla.launch()
为父项的调用堆栈中的函数)内部开始检索和访问 XLA 设备。
有关在多个 XLA 设备上使用多进程训练网络的更多信息,请参阅 完整的 multiprocessing 示例。
在 TPU Pod 上运行¶
不同加速器的多主机设置可能差异很大。本文档将讨论多主机训练中与设备无关的部分,并以 TPU + PJRT 运行时(目前在 1.13 和 2.x 版本中可用)为例。
在开始之前,请先查看我们的用户指南 这里,它将解释一些 Google Cloud 的基础知识,例如如何使用 gcloud
命令以及如何设置您的项目。您也可以查看 这里 获取所有 Cloud TPU 操作指南。本文档将重点介绍设置的 PyTorch/XLA 视角。
假设您在上一个部分中看到的 mnist 示例位于 train_mnist_xla.py
文件中。如果这是单主机多设备训练,您将 SSH 到 TPUVM 并运行如下命令:
PJRT_DEVICE=TPU python3 train_mnist_xla.py
现在,为了在 TPU v4-16(有两个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要:- 确保每个主机都能访问训练脚本和训练数据。这通常通过使用 gcloud scp
命令或 gcloud ssh
命令将训练脚本复制到所有主机来完成。- 在所有主机上同时运行相同的训练命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
上面的 gcloud ssh
命令将 SSH 到 TPUVM Pod 中的所有主机,并同时运行相同的命令。
注意: 您需要在 TPUVM 虚拟机外部运行上述
gcloud
命令。
模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础设施将确保每个设备都了解全局拓扑以及每个设备的本地和全局序数。跨设备通信将在所有设备之间发生,而不是仅限于本地设备。
有关 PJRT 运行时的更多详细信息以及如何在 Pod 上运行它,请参阅此 文档。有关 PyTorch/XLA 和 TPU Pod 的更多信息,以及在 TPU Pod 上运行 resnet50 和 fakedata 的完整指南,请参阅此 指南。
XLA Tensor 深度解析¶
使用 XLA Tensor 和设备只需更改几行代码。但即使 XLA Tensor 的行为与 CPU 和 CUDA Tensor 非常相似,它们的内部机制也不同。本节将介绍 XLA Tensor 的独特性。
XLA Tensors 是惰性的¶
CPU 和 CUDA Tensor 会立即或即时执行操作。而 XLA Tensor 是惰性的。它们会在图中记录操作,直到需要结果为止。这种延迟执行允许 XLA 进行优化。例如,由多个独立操作组成的图可能会合并为单个优化操作。
惰性执行通常对调用者是不可见的。PyTorch/XLA 会自动构建图,将它们发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在获取优化器步骤时插入屏障会显式同步 CPU 和 XLA 设备。有关我们惰性 Tensor 设计的更多信息,您可以阅读 这篇论文。
内存布局¶
XLA Tensor 的内部数据表示对用户是不透明的。它们不暴露其存储,并且始终显得是连续的,这与 CPU 和 CUDA Tensor 不同。这允许 XLA 调整 Tensor 的内存布局以获得更好的性能。
将 XLA Tensor 移动到 CPU 和从 CPU 移动¶
XLA Tensor 可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动的是一个视图,那么它所查看的数据也会被复制到另一个设备,并且视图关系不会被保留。换句话说,一旦数据被复制到另一个设备,它就与其前一个设备或上面的任何 Tensor 没有关系。再次强调,根据您的代码操作方式,理解和适应这种过渡可能很重要。
保存和加载 XLA Tensors¶
保存 XLA Tensor 之前应将其移动到 CPU,如下面的代码片段所示:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = torch_xla.device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
这允许您将加载的 Tensor 放在任何可用的设备上,而不仅仅是初始化它们的设备。
根据上面关于将 XLA Tensor 移动到 CPU 的说明,在处理视图时必须小心。建议不要保存视图,而是在 Tensor 已加载并移动到其目标设备(s) 后重新创建它们。
提供了一个实用 API,通过移动到 CPU 来处理数据保存:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
对于多个设备,上面的 API 只会保存主设备序数(0)的数据。
在内存相对于模型参数大小有限的情况下,提供了一个可以减少主机内存占用的 API:
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
此 API 一次将 XLA Tensor 流式传输到 CPU,从而减少主机内存使用量,但需要匹配的加载 API 来恢复。
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
直接保存 XLA Tensor 是可能的,但不推荐。XLA Tensor 始终会加载回它们被保存的设备,如果该设备不可用,则加载会失败。PyTorch/XLA 与所有 PyTorch 一样,处于积极的开发中,此行为将来可能会发生变化。
AOT(提前)跟踪期间的意外 Tensor 具体化¶
虽然 Tensor 具体化对于 JIT 工作流程是正常的,但在跟踪推理(即 AWS Neuron 中的 AOT 模型跟踪)期间,它是不期望的。在使用跟踪推理时,开发人员可能会遇到 Tensor 具体化,这会导致图基于示例输入 Tensor 值进行编译,并产生意外的程序行为。因此,我们需要利用 PyTorch/XLA 的调试标志来识别意外 Tensor 具体化何时发生,并进行适当的代码更改以避免 Tensor 具体化。
当 Tensor 值在模型编译(跟踪推理)期间被评估时,会发生一个常见问题。考虑以下示例:
def forward(self, tensor):
if tensor[0] == 1:
return tensor
else:
return tensor * 2
虽然此代码可以编译和运行,但可能会导致意外行为,因为:
在跟踪期间访问 Tensor 值(
tensor[0]
)。生成的图基于跟踪期间可用的 Tensor 值进行固定。
开发人员可能会错误地认为条件将在推理期间动态评估。
上述代码的解决方案是利用下面的调试标志来捕获问题并修改代码。一个例子是通过模型配置传递标志:
请参阅没有 Tensor 具体化的更新代码:
class TestModel(torch.nn.Module):
def __init__(self, flag=1):
super().__init__()
# the flag should be pre-determined based on the model configuration
# it should not be an input of the model during runtime
self.flag = flag
def forward(self, tensor):
if self.flag:
return tensor
else:
return tensor * 2
调试标志¶
为了帮助捕获 Tensor 具体化问题,PyTorch/XLA 提供了两种有用的方法:
为 Tensor 具体化启用警告消息
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
禁用图执行以在开发过程中捕获问题
import torch_xla
torch_xla._XLAC._set_allow_execution(False)
建议¶
在开发过程中使用这些标志有助于在开发周期的早期识别潜在问题。建议的方法是:
在初始开发过程中使用
PT_XLA_DEBUG_LEVEL=2
来识别潜在的具体化点。当您想确保在跟踪期间不发生 Tensor 具体化时,应用
_set_allow_execution(False)
。当看到与 Tensor 具体化相关的警告或错误时,请检查代码路径并进行适当的修改。上面的示例将标志移到了
__init__
函数中,该函数在运行时不依赖于模型输入。
有关更详细的调试信息,请参阅 XLA 故障排除。
编译缓存¶
XLA 编译器将跟踪的 HLO 转换为在设备上运行的可执行文件。编译可能非常耗时。如果 HLO 在不同执行之间不发生变化,则编译结果可以持久化到磁盘以供重用,从而显著缩短开发迭代时间。
注意
如果 HLO 在执行之间发生变化,仍然会发生重新编译。
当
torch_xla
的版本发生变化时,会发生重新编译(以便我们可以使用最新的编译器生成可执行文件)。
这目前是一个选择加入的 API,必须在执行任何计算之前激活。初始化通过 initialize_cache
API 进行:
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
这将初始化指定路径下的持久编译缓存。可以readonly
参数来控制工作进程是否能够写入缓存,这在共享缓存挂载用于 SPMD 工作负载时可能很有用。
如果您想在多进程训练(使用 torch_xla.launch
或 xmp.spawn
)中使用持久编译缓存,您应该为不同的进程使用不同的路径。
def _mp_fn(index):
# cache init needs to happens inside the mp_fn.
xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
....
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
如果您没有访问 index
的权限,您可以使用 xr.global_ordinal()
。请查看 此处 的可运行示例。
延伸阅读¶
更多文档可在 PyTorch/XLA 仓库 中找到。有关在 TPU 上运行网络的更多示例,请参阅 此处。