PJRT Runtime¶
PyTorch/XLA 已从基于 TensorFlow 的 XRT runtime 迁移到 JAX 使用的 PJRT runtime(位于 PJRT runtime)。
如果您遇到 PJRT 的 bug,请在 GitHub 上提交 issue,并加上 runtime
标签。
PyTorch/XLA r2.1 中的新功能:
PJRT 在 PyTorch/XLA r2.1 中已稳定!
公共 runtime API 已从
torch_xla.experimental.pjrt
迁移到torch_xla.runtime
。pjrt://
初始化方法已重命名为xla://
,并由torch_xla.distributed.xla_backend
注册。为了兼容性,本版本中仍然提供之前的
torch_xla.experimental.*
名称。
现在使用
init_method='xla://'
时支持torchrun
。通过 PJRT C API 为 XPU 和 Neuron 添加了新的插件。
PyTorch/XLA r2.0 中的新功能:
如果您不指定任何其他 runtime 配置,PJRT 将被默认配置。如果您继续设置 XRT 配置(
XRT_TPU_CONFIG
),此更改没有影响。libtpu
中的新 TPU runtime 实现将性能提高了高达 30%。新的
xm.rendezvous
实现,可扩展到数千个 TPU 核心。[experimental]
torch.distributed
支持 TPU v2 和 v3,包括pjrt://
init_method
。
TL;DR¶
要使用 PJRT 预览 runtime,请将
PJRT_DEVICE
环境变量设置为CPU
、TPU
或CUDA
。在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 中,TPU v2 和 v3 上的工作负载是多进程和多线程的(4 个进程,每个进程 2 个线程),因此您的工作负载应该是线程安全的。有关更多信息,请参阅 TPU v2/v3 上的多线程 和 API 指南的多进程部分。关键区别需要牢记:
为了以线程安全的方式初始化模型,请在初始化后跨副本广播参数(
torch_xla.experimental.pjrt.broadcast_master_param
),或者从通用检查点加载每个副本的参数。对于其他随机数生成,请尽可能使用
torch.Generator
。全局torch
RNG **不是**线程安全的,即使您在每个副本上设置了相同的torch.manual_seed
。要使用
torch.distributed
,请导入torch_xla.experimental.pjrt_backend
并使用xla://
init_method
。对于 GPU 和 TPU v4,这些步骤是可选的。
XRT 到 PJRT 的示例差异
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr
def _mp_fn(index):
device = xm.xla_device()
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')
torch.manual_seed(42)
model = nn.Linear(128, 10).to(device)
+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
for i in range(10):
data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Print mean parameters so we can confirm they're the same across replicas
print([p.mean() for p in model.parameters()])
if __name__ == '__main__':
- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
+ # Recommended: set PJRT_DEVICE to your local device type
+ os.environ['PJRT_DEVICE'] = 'TPU'
torch_xla.launch(_mp_fn)
优点¶
简单的 runtime 配置:只需将
PJRT_DEVICE
设置为TPU
、CPU
或CUDA
,即可开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。性能提升:gRPC 开销减少,从而实现更快的端到端执行。在 TorchBench 2.0 上,我们在 TPU v4 上的训练时间方面观察到了超过 35% 的提升。
轻松执行 Pod:只需将代码复制到每个 TPU worker,然后使用
gcloud compute tpus tpuvm ssh --worker=all
同时执行它们。更好的扩展性:消除了 XRT 对参数大小的限制,并支持高达 2048 个 TPU 芯片。
快速入门¶
要开始使用 PyTorch/XLA 的 PJRT,您只需设置 PJRT_DEVICE
环境变量。如果您正在使用 TPU v2 或 v3,请继续阅读以了解 TPU v2 和 v3 与 v4 之间的区别。
CPU¶
在任何安装了 PyTorch/XLA 的机器上,您都可以像这样在 CPU 上运行我们的 MNIST 示例:
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
TPU¶
要使用安装了 PyTorch/XLA r2.0 的新 TPU:
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT
在 v4-8 上,您可以像这样运行我们的 ResNet50 示例:
git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
默认情况下,PJRT 将使用所有 TPU 芯片。要只使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDS
和 TPU_VISIBLE_CHIPS
。
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Pods¶
在 TPU Pods 上,使用 gcloud
在每个 TPU 上并行运行您的命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
Docker¶
您还可以使用 Docker 在预装了 PyTorch/XLA 的容器中运行您的工作负载。
export DOCKER_IMAGE=gcr.io/...
# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"
# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
请注意,docker run
需要对主机的特权访问(--privileged
)才能将 TPU 设备暴露给容器。目前,TPU Pod 上的 Docker 只支持主机网络模式(--net=host
)。有关更多信息,请参阅 Cloud TPU 文档。
GPU¶
单节点 GPU 训练¶
要将 GPU 与 PJRT 结合使用,只需设置 PJRT_DEVICE=CUDA
并将 GPU_NUM_DEVICES
配置为主机上的设备数量。例如:
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
您还可以使用 torchrun
来启动单节点多 GPU 训练。例如:
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在上面的示例中,--nnodes
表示要使用的机器(物理机或虚拟机)数量(单节点训练时为 1)。--nproc-per-node
表示要使用的 GPU 设备数量。
多节点 GPU 训练¶
请注意,此功能仅适用于 cuda 12+。与 PyTorch 使用多节点训练类似,您可以如下运行命令:
PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
--nnodes
:要使用的 GPU 机器数量。--node_rank
:当前 GPU 机器的索引。该值可以是 0、1、...、${NUMBER_GPU_VM}-1。--nproc_per_node
:当前机器上要使用的 GPU 设备数量。--rdzv_endpoint
:node_rank==0 的 GPU 机器的端点,格式为host:port
。host
将是内部 IP 地址。port
可以是机器上的任何可用端口。对于单节点训练/推理,可以省略此参数。
例如,如果您想在 2 台 GPU 机器(machine_0 和 machine_1)上进行训练,在第一台 GPU 机器 machine_0 上运行:
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在第二台 GPU 机器上运行:
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
以上两个命令的区别在于 --node_rank
,以及如果您想在每台机器上使用不同数量的 GPU 设备,则可能还有 --nproc_per_node
。其他所有内容都相同。有关 torchrun
的更多信息,请参阅此页面。
与 XRT 的区别¶
尽管在大多数情况下,我们期望 PJRT 和 XRT 从最终用户的角度来看(尤其是在 TPU v4 上)都能基本互换工作,但存在一些细微的差别,这些差别很重要。重要的是,XRT 是围绕 TPU Node 架构设计的,因此即使在 TPU VM 上,它也总是会启动一个客户端和一个服务器进程。因此,每批输入都会因序列化和反序列化数据以通过网络发送而产生额外的延迟。
PJRT 直接使用本地设备,没有中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档。
对于受限制于开销的工作负载,可以获得性能提升。
在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,而客户端进程无法直接访问 TPU 设备。在对单主机 TPU(例如 v3-8 或 v4-8)进行分析时,通常会看到 8 个设备跟踪(每个 TPU 核心一个)。使用 PJRT,每个进程有一个芯片,并且来自该进程的分析将仅显示 2 个 TPU 核心。
出于同样的原因,在 TPU Pod 上使用 XRT 进行分析无效,因为服务器进程独立于用户的模型代码运行。PJRT 没有这个限制,因此可以在 TPU Pod 中对每个进程的 2 个 TPU 核心进行分析。
PJRT 只支持 TPU VM 架构,我们不计划使用 PJRT 支持 TPU Node 架构。
使用 PJRT 的 runtime 配置显著简化。
xla_dist
不需要运行 TPU Pod 工作负载。相反,将代码复制到每个 TPU 主机([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)
),然后并行在每个主机上运行代码(例如[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)
)。xm.rendezvous
已使用 XLA 原生集体通信重新实现,以提高在大规模 TPU Pod 上的稳定性。有关更多详细信息,请参阅下文。
TPU v2/v3 上的多线程¶
在 TPU v2 和 v3 上,**分布式工作负载始终以多线程方式运行**,因为每个 TPU 核心将两个 TPU 核心公开为设备,并且一次只有一个进程可以打开 TPU 芯片。在其默认配置中,xmp.spawn
会自动生成尽可能多的进程(每个 TPU 主机 4 个)并为每个进程创建两个线程(每个 TPU 核心一个)。
注意:在 TPU v4 上,每个 TPU 芯片表示为一个 PyTorch 设备,因此分布式工作负载将在 4 个进程中运行,每个进程只有一个线程。这与 XRT 的行为相同。
在大多数情况下,这不需要对现有代码进行大量修改。在大多数情况下,您需要做出的主要更改是模型初始化。由于 torch
的全局 RNG 在线程之间共享,因此即使您在每个副本中将 torch.manual_seed
设置为相同的值,不同线程和运行结果也会有所不同。为了在副本之间获得一致的参数,请使用 torch_xla.experimental.pjrt.broadcast_master_param
将一个副本的参数广播到所有其他副本,或者从通用检查点加载每个副本的参数。
xm.rendezvous 的更改¶
PyTorch/XLA r2.0 新增内容
在使用 XRT 时,worker 0 运行一个 mesh master 服务,并且所有 worker 上的所有进程通过 gRPC 连接到该服务。实际上,我们发现由于 worker 0 的入站连接数量过多,运行单个 mesh master 进程在拥有数千个芯片的 TPU Pod 上并不可靠。单个客户端进程超时可能会导致失败并迫使整个工作负载重新启动。
因此,我们使用 XLA 原生集体通信重新实现了 xm.rendezvous
,它在大规模 TPU Pod 上更稳定且经过充分测试。与 XRT 实现相比,这带来了两个新的约束:
由于有效载荷必须成为 XLA 图的一部分,因此
xm.mark_step
在数据传输之前和之后都会被调用。在模型代码中间调用xm.rendezvous
可能会强制进行不必要的编译。由于 XLA 不允许在部分 worker 上运行集体操作,因此所有 worker 都必须参与
rendezvous
。
如果您需要 xm.rendezvous
的旧行为(即,在不修改 XLA 图的情况下通信数据和/或同步部分 worker),请考虑使用 `torch.distributed.barrier
<https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.barrier>[__ 或 ]{.title-ref}torch.distributed.all_gather_object
<https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.all_gather_object>[__ 与 ]{.title-ref}[gloo]{.title-ref}[ 进程组。如果您还使用 ]{.title-ref}[xla]{.title-ref}[ ]{.title-ref}[torch.distributed]{.title-ref}[ 后端,您可以使用 ]{.title-ref}[torch.new*group]{.title-ref}[ 来创建一个 ]{.title-ref}[gloo]{.title-ref}[ 子组。请参阅 PyTorch 文档中的`此示例 https://pytorch.ac.cn/docs/stable/distributed.html#monitored-barrier]{.title-ref}*。请记住这些约束:
torch.distributed
在 TPU v2/v3 上并未完全支持。仅实现了带有xla
后端的子集操作,并且gloo
在多线程环境中可能无法按预期工作。在我们的实验中,
gloo
在数千个 TPU 芯片上的扩展性不佳,因此预计此替代方案在大规模使用xm.rendezvous
与 PJRT 相比可靠性较低。
PJRT 和 torch.distributed¶
PyTorch/XLA r2.0 新增内容
当使用 PJRT 结合 torch.distributed
和 [torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
时,我们强烈建议使用新的 xla://
init_method
,它通过查询 runtime 自动查找副本 ID、世界大小和主 IP。例如:
import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt
# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend
def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='xla://')
t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
xm.mark_step()
print(output)
if __name__ == '__main__':
torch_xla.launch(_all_gather)
注意:尽管 xla://
init_method 在 TPU v4 上不是必需的,但仍然推荐使用。如果您使用 env://
,则必须将 MASTER_ADDR
设置为拥有设备 0 的 IP 主机,而这**不**总是 worker 0。xla://
init_method 会自动找到此 IP。
注意:对于 TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend
,因为 torch.distributed
中对 TPU v2/v3 的支持仍处于实验阶段。
有关在 PyTorch/XLA 上使用 DistributedDataParallel
的更多信息,请参阅 TPU V4 上的 ddp.md。对于同时使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下示例脚本。
PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1
性能¶
TorchBench 显示,与 XRT 相比,PJRT 在任务平均训练时间上有所提高,在 TPU v4-8 上平均提高超过 35%。收益因任务和模型类型而异,范围从 0% 到 175%。下表按任务分解了详细信息:
新的 TPU runtime¶
PyTorch/XLA r2.0 新增内容
PyTorch/XLA r2.0 版本引入了对 PJRT Plugin API 的支持,该 API 用于访问 libtpu
中新的基于 TFRT 的 TPU runtime。现在,当设置 PJRT_DEVICE=TPU
时,它是默认的 runtime。在 2.0 版本中,1.13 版本中使用的旧版 StreamExecutor 基于的 TPU runtime 仍将可用,使用 PJRT_DEVICE=TPU_LEGACY
,但它将在未来的版本中被移除。如果您遇到仅在 TPU
上而不是 TPU_LEGACY
上发生的 issue,请在 GitHub 上提交 issue。
在大多数情况下,我们预计两个 runtime 的性能相似,但在某些情况下,新 runtime 的速度可能提高高达 30%。下表按任务分解了详细信息:
注意:此表中显示的改进也包含在 PJRT 与 XRT 的比较中。