PyTorch on XLA Devices¶
PyTorch 在 TPU 等 XLA 设备上运行,使用 torch_xla 包。本文档介绍了如何在这些设备上运行模型。
创建 XLA Tensor¶
PyTorch/XLA 为 PyTorch 添加了一个新的 xla
设备类型。此设备类型与其他 PyTorch 设备类型的工作方式相同。例如,以下是如何创建和打印 XLA Tensor:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
这段代码应该很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并进行了一些扩展。导入 torch_xla
会初始化 PyTorch/XLA,而 xm.xla_device()
会返回当前 XLA 设备。根据您的环境,这可能是 CPU 或 TPU。
XLA Tensors 是 PyTorch Tensors¶
PyTorch 操作可以像在 CPU 或 CUDA Tensor 上一样在 XLA Tensor 上执行。
例如,XLA Tensors 可以相加
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
或者进行矩阵乘法
print(t0.mm(t1))
或者与神经网络模块一起使用
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
与其他设备类型一样,XLA Tensors 只能与同一设备上的其他 XLA Tensors 一起工作。因此,类似这样的代码:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
将引发错误,因为 torch.nn.Linear
模块位于 CPU 上。
在 XLA 设备上运行模型¶
构建新的 PyTorch 网络或将现有网络转换为在 XLA 设备上运行,只需要几行 XLA 特定代码。以下代码片段在单个设备和使用 XLA 多进程的多个设备上运行时突出显示了这些行。
在单个 XLA 设备上运行¶
以下代码片段展示了一个在单个 XLA 设备上训练的网络:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
此代码片段突出了将模型切换到 XLA 上运行的简便性。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一特定于 XLA 的代码是几行,用于获取 XLA 设备并标记步进。在每次训练迭代结束时调用 xm.mark_step()
会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA Tensor 深入解析。
使用多进程在多个 XLA 设备上运行¶
PyTorch/XLA 可以轻松地通过在多个 XLA 设备上运行来加速训练。以下代码片段展示了如何实现:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
def _mp_fn(index):
device = xm.xla_device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in mp_device_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
与之前的单设备代码片段相比,此多设备代码片段有三个不同之处。我们逐一进行说明。
torch_xla.launch()
创建运行 XLA 设备的进程。
此函数是多线程 spawn 的包装器,允许用户通过 torchrun 命令行运行脚本。每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将启动 4 个进程,每个进程拥有一个 TPU 设备。
请注意,如果在每个进程中打印
xm.xla_device()
,您会在所有设备上看到xla:0
。这是因为每个进程只能看到一个设备。这并不意味着多进程不起作用。在 TPU v2 和 TPU v3 上,仅使用 PJRT 运行时,因为会有#devices/2
个进程,每个进程有 2 个线程(有关更多详细信息,请参阅此 文档)。
MpDeviceLoader
将训练数据加载到每个设备上。
MpDeviceLoader
可以包装 torch dataloader。它可以预加载数据到设备,并重叠数据加载与设备执行以提高性能。MpDeviceLoader
还会为您每batches_per_execution
(默认为 1)批次调用xm.mark_step
。
xm.optimizer_step(optimizer)
在设备之间合并梯度并发出 XLA 设备步进计算。
它基本上是
all_reduce_gradients
+optimizer.step()
+mark_step
,并返回归约后的损失。
模型定义、优化器定义和训练循环保持不变。
注意:重要的是要注意,在使用多进程时,用户只能在
torch_xla.launch()
的目标函数(或以torch_xla.launch()
作为调用堆栈父级的任何函数)中开始检索和访问 XLA 设备。
有关在多个 XLA 设备上进行多进程训练的更多信息,请参阅 完整的 multiprocessing 示例。
在 TPU Pod 上运行¶
不同加速器的多主机设置可能非常不同。本文档将讨论多主机训练中与设备无关的部分,并将以 TPU + PJRT 运行时(目前在 1.13 和 2.x 版本中可用)为例。
在开始之前,请查看我们的用户指南 这里,它将解释一些 Google Cloud 的基础知识,例如如何使用 gcloud
命令以及如何设置您的项目。您还可以查看 这里 获取所有 Cloud TPU 操作指南。本文档将侧重于设置的 PyTorch/XLA 视角。
假设您有上面章节中的 mnist 示例,将其保存在 train_mnist_xla.py
文件中。如果是单主机多设备训练,您将 SSH 到 TPUVM 并运行如下命令:
PJRT_DEVICE=TPU python3 train_mnist_xla.py
现在,为了在 TPU v4-16(具有 2 个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要: - 确保每个主机都可以访问训练脚本和训练数据。这通常通过使用 gcloud scp
命令或 gcloud ssh
命令将训练脚本复制到所有主机来完成。 - 在所有主机上同时运行相同的训练命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
上面的 gcloud ssh
命令将 SSH 到 TPUVM Pod 中的所有主机,并同时运行相同的命令。
注意:您需要在 TPUVM 虚拟机外部运行上述
gcloud
命令。
模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础架构将确保每个设备都了解全局拓扑以及每个设备的本地和全局序数。跨设备通信将发生在所有设备上,而不是仅限于本地设备。
有关 PJRT 运行时及其在 Pod 上运行的更多详细信息,请参阅此 文档。有关 PyTorch/XLA 和 TPU Pod 的更多信息,以及在 TPU Pod 上运行 resnet50(使用 fakedata)的完整指南,请参阅此 指南。
XLA Tensor 深入解析¶
使用 XLA Tensor 和设备只需要更改几行代码。但尽管 XLA Tensor 的作用与 CPU 和 CUDA Tensor 非常相似,但它们的内部结构却不同。本节将介绍 XLA Tensor 的独特性。
XLA Tensors 是惰性的¶
CPU 和 CUDA Tensor 会立即或以贪婪方式启动操作。而 XLA Tensor 则是惰性的。它们在图中记录操作,直到需要结果时才执行。通过延迟执行,XLA 可以进行优化。例如,一个包含多个独立操作的图可以合并成一个优化的操作。
惰性执行通常对调用者是不可见的。PyTorch/XLA 会自动构建图,将其发送到 XLA 设备,并在 XLA 设备与 CPU 之间复制数据时进行同步。在优化器步骤中插入屏障会显式同步 CPU 和 XLA 设备。有关我们惰性 Tensor 设计的更多信息,您可以阅读 这篇论文。
内存布局¶
XLA Tensor 的内部数据表示对用户来说是不透明的。它们不暴露其存储,并且总是显得是连续的,这与 CPU 和 CUDA Tensor 不同。这使得 XLA 可以调整 Tensor 的内存布局以获得更好的性能。
在 CPU 和 XLA Tensor 之间移动¶
XLA Tensor 可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动的是一个视图,那么它所查看的数据也会被复制到另一个设备,并且视图关系不会被保留。换句话说,一旦数据被复制到另一个设备,它就与之前的设备或其上的任何 Tensor 没有关系了。同样,根据代码的运行方式,理解并适应这种转换可能很重要。
保存和加载 XLA Tensor¶
在保存 XLA Tensor 之前,应将其移动到 CPU,如下面的代码片段所示:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
这允许您将加载的 Tensor 放在任何可用的设备上,而不仅仅是初始化它们的设备。
根据上面关于将 XLA Tensor 移动到 CPU 的注意事项,在处理视图时需要小心。建议不要保存视图,而是建议在 Tensor 加载并移动到目标设备后重新创建它们。
提供了一个实用 API,通过处理预先将其移动到 CPU 来保存数据:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
在有多个设备的情况下,上述 API 只会保存主设备序数(0)的数据。
在内存量相对于模型参数大小时受限的情况下,提供了一个可以减少主机内存占用的 API:
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
此 API 将 XLA Tensor 一次一个地流式传输到 CPU,从而减少了主机内存的使用量,但它需要匹配的加载 API 才能恢复。
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
直接保存 XLA Tensor 是可能的,但不推荐。XLA Tensor 总是被加载回它们保存的设备,如果该设备不可用,则加载会失败。PyTorch/XLA 与所有 PyTorch 一样,处于积极开发中,此行为将来可能会发生变化。
编译缓存¶
XLA 编译器将跟踪的 HLO 转换为可在设备上运行的可执行文件。编译可能非常耗时,在 HLO 在执行之间不发生变化的情况下,可以将编译结果持久化到磁盘以供重用,从而显著缩短开发迭代时间。
请注意,如果在执行之间 HLO 发生变化,仍然会发生重新编译。
这目前是一个实验性的选择加入 API,必须在执行任何计算之前激活。初始化通过 initialize_cache
API 进行。
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
这将初始化位于指定路径的持久编译缓存。readonly
参数可用于控制工作进程是否能够写入缓存,这在使用共享缓存挂载进行 SPMD 工作负载时非常有用。
如果您想在多进程训练(使用 torch_xla.launch
或 xmp.spawn
)中使用持久编译缓存,您应该为每个进程使用不同的路径。
def _mp_fn(index):
# cache init needs to happens inside the mp_fn.
xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
....
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
如果您没有 index
的访问权限,您可以使用 xr.global_ordinal()
。请查看此处的可运行示例 这里。
延伸阅读¶
更多文档可在 PyTorch/XLA 仓库 获取。有关在 TPU 上运行网络的更多示例,请访问 这里。