PyTorch/XLA 概述¶
本节简要概述 PyTorch XLA 的基本细节,这将有助于读者更好地理解所需的代码修改和优化。
与常规 PyTorch 不同,常规 PyTorch 是逐行执行代码,并且在获取 PyTorch 张量的值之前不会阻塞执行,PyTorch XLA 的工作方式有所不同。它会遍历 Python 代码,并将(PyTorch)XLA 张量上的操作记录在一个中间表示 (IR) 图中,直到遇到一个屏障(稍后讨论)。此生成 IR 图的过程称为跟踪(LazyTensor 跟踪或代码跟踪)。PyTorch XLA 然后将 IR 图转换为一种更低级别的机器可读格式,称为 HLO(High-Level Opcodes)。HLO 是 XLA 编译器特有的计算表示,它允许 XLA 编译器为正在运行的硬件生成高效代码。HLO 被馈送到 XLA 编译器进行编译和优化。然后,PyTorch XLA 会缓存编译结果,以便在需要时重用。图的编译在主机(CPU)上完成,也就是运行 Python 代码的机器。如果存在多个 XLA 设备,主机将为每个设备单独编译代码,除非使用 SPMD(single-program, multiple-data)。例如,v4-8 有一个主机机器和四个设备。在这种情况下,主机将为这四个设备分别编译代码。对于 pod 切片,当存在多个主机时,每个主机将为其连接的 XLA 设备执行编译。如果使用 SPMD,则对于给定的形状和计算,代码将在每个主机上为所有设备仅编译一次。
有关更多详细信息和示例,请参阅LazyTensor 指南。
IR 图中的操作仅在需要张量值时才执行。这被称为张量的求值或具体化。有时也称为延迟求值,它可以带来显著的性能提升。
PyTorch XLA 中的*同步*操作,如打印、日志记录、检查点或回调,会阻塞跟踪并导致执行速度变慢。当某个操作需要 XLA 张量的特定值时,例如print(xla_tensor_z)
,跟踪将阻塞,直到该张量的值对主机可用为止。请注意,只有负责计算该张量值的图的部分会被执行。这些操作不会切断 IR 图,但它们会通过TransferFromDevice
触发主机-设备通信,从而导致性能下降。
一个*屏障*是一个特殊的指令,它告诉 XLA 执行 IR 图并具体化张量。这意味着 PyTorch XLA 张量将被求值,并且结果将对主机可用。PyTorch XLA 中用户暴露的屏障是xm.mark_step(),它会中断 IR 图,并导致在 XLA 设备上执行代码。`xm.mark_step` 的一个关键属性是,与同步操作不同,它在设备执行图时不会阻塞进一步的跟踪。但是,它确实会阻塞对正在具体化的张量值的访问。
LazyTensor 指南中的示例说明了一个简单地将两个张量相加的简单情况。现在,假设我们有一个 for 循环,它会添加 XLA 张量并在之后使用该值。
for x, y in tensors_on_device:
z += x + y
没有屏障,Python 跟踪将生成一个图,该图将len(tensors_on_device)
次相加张量的操作。这是因为for
循环没有被跟踪捕获,因此循环的每次迭代都会创建一个新的子图,对应于z += x+y
的计算,并将其添加到图中。下面是一个len(tensors_on_device)=3
的例子。

然而,在循环结束时引入屏障将生成一个较小的图,该图将在for
循环的第一次传递期间编译一次,并将在接下来的len(tensors_on_device)-1
次迭代中重复使用。屏障将向跟踪发出信号,表明到目前为止已跟踪的图可以提交执行,如果该图以前见过,则会重用已编译的程序。.
for x, y in tensors_on_device:
z += x + y
xm.mark_step()
在这种情况下,将有一个小型图被使用len(tensors_on_device)=3
次。

需要强调的是,在 PyTorch XLA 中,for 循环内的 Python 代码会被跟踪,并且如果末尾有屏障,则每次迭代都会构建一个新的图。这可能是一个重要的性能瓶颈。
当相同的计算在相同形状的张量上发生时,XLA 图可以被重用。如果输入或中间张量的形状发生变化,XLA 编译器将重新编译一个具有新张量形状的新图。这意味着,如果您有动态形状或您的代码不重用张量图,则在 XLA 上运行您的模型将不适合该用例。将输入填充到固定形状是避免动态形状的一种选择。否则,编译器将花费大量时间来优化和融合不会再次使用的操作。
图大小和编译时间之间的权衡也很重要。如果有一个大的 IR 图,XLA 编译器可能会花费大量时间来优化和融合操作。这可能导致编译时间很长。然而,由于编译期间执行的优化,后来的执行可能会更快。
有时值得使用xm.mark_step()
来打破 IR 图。如上所述,这将生成一个更小的图,以后可以重用。然而,缩小图可能会减少 XLA 编译器可以进行的优化。
另一个重要的考虑因素是MPDeviceLoader。一旦您的代码在 XLA 设备上运行,请考虑使用 XLA MPDeviceLoader
包装 torch 数据加载器,该加载器将数据预加载到设备以提高性能,并包含xm.mark_step()
。后者会自动中断数据批次的迭代并将它们发送执行。请注意,如果您不使用 MPDeviceLoader,则可能需要在optimizer_step()
中设置barrier=True
以启用xm.mark_step()
(如果正在运行训练作业),或者显式添加xm.mark_step()
。
TPU 设置¶
创建具有基本映像的 TPU 以使用 nightly wheels,或通过指定RUNTIME_VERSION
从稳定版本创建。
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, …
export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base
export TPU_NAME=your_tpu_name
gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet
如果您有一个单主机 VM(例如 v4-8),您可以 ssh 到您的 vm 并直接从 vm 运行以下命令。否则,对于 TPU pod,您可以使用--worker=all --command=""
,类似于
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl"
接下来,如果您使用基本映像,请安装 nightly 包和必需的库。
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
sudo apt-get install libopenblas-dev -y
sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific
参考实现¶
AI-Hypercomputer/tpu-recipies 存储库包含训练和部署许多 LLM 和扩散模型的示例。
将代码转换为 PyTorch XLA¶
修改代码的一般指南
将
cuda
替换为xm.xla_device()
删除进度条、会访问 XLA 张量值的打印操作。
减少会访问 XLA 张量值的日志记录和回调。
用 MPDeviceLoader 包装数据加载器。
进行性能分析以进一步优化代码。
记住:每个情况都是独特的,您可能需要为每种情况做不同的事情。
示例 1. 在 PyTorch Lightning 中,在单个 TPU 设备上进行 Stable Diffusion 推理¶
作为第一个示例,请考虑 PyTorch Lightning 中 stable diffusion 模型的推理代码,可以从命令行执行为
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"
供您参考,下面描述的修改差异可以在此处找到。让我们一步一步地进行。与上面的通用指南一样,从与cuda
设备相关的更改开始。此推理代码是为在 GPU 上运行而编写的,并且cuda
出现在多个地方。通过从此行中删除model.cuda()
以及从此处删除precision_scope
开始进行更改。此外,将此行中的cuda
设备替换为xla
设备,类似于下面的代码。
接下来,该模型的特定配置使用了FrozenCLIPEmbedder
,因此我们将修改此行。为了简单起见,我们将在本教程中直接定义device
,但您也可以将device
值传递给函数。
import torch_xla.core.xla_model as xm
self.device = xm.xla_device()
代码中的另一个具有 cuda 特定代码的地方是 DDIM 调度程序。在文件顶部添加import torch_xla.core.xla_model as xm
,然后替换这些行。
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
替换
device = xm.xla_device()
attr = attr.to(torch.device(device))
接下来,您可以通过删除打印语句、禁用进度条以及减少或删除回调和日志记录来减少设备(TPU)和主机(CPU)之间的通信。这些操作需要设备停止执行,回退到 CPU,执行日志记录/回调,然后返回到设备。这可能是一个重要的性能瓶颈,尤其是在大型模型上。
进行这些更改后,代码将在 TPU 上运行。但是,性能会非常慢。这是因为 XLA 编译器尝试构建一个单一的(巨大的)图,该图包装了推理步骤的数量(在本例中为 50),因为 for 循环内没有屏障。编译器很难优化图,这会导致严重的性能下降。如上所述,使用屏障(xm.mark_step())中断 for 循环将生成一个更小的图,编译器更容易优化。这还将允许编译器重用上一步的图,从而提高性能。
现在代码已准备好在合理的时间内用于 TPU。可以通过捕获配置文件并进一步研究来进行更多优化和分析。但是,这里不包括这些内容。
注意:如果您在 v4-8 TPU 上运行,则有 4 个可用的 XLA(TPU)设备。如上运行代码将仅使用一个 XLA 设备。为了在所有 4 个设备上运行,您需要使用torch_xla.launch()
函数将代码分发到所有设备。我们将在下一个示例中讨论torch_xla.launch
。
示例 2. HF Stable Diffusion 推理¶
现在,考虑使用 HuggingFace diffusers 库中的文本到图像推理来处理 SD-XL 和 2.1 版本模型。供您参考,下面描述的更改可以在此repo中找到。您可以克隆存储库并使用以下命令在 TPU VM 上运行推理。
(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git
(vm)$ cd diffusers/examples/text_to_image/
(vm)$ python3 inference_tpu_single_device.py
在单个 TPU 设备上运行¶
本节介绍将文本到图像推理示例代码在 TPU 上运行所需的更改。
原始代码使用 Lora 进行推理,但本教程不使用它。相反,我们在初始化管道时将model_id
参数设置为stabilityai/stable-diffusion-xl-base-0.9
。我们还将使用默认调度程序(DPMSolverMultistepScheduler)。但是,类似更改也可以应用于其他调度程序。
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install . # pip install -e .
cd examples/text_to_image/
pip install -r requirements.txt
pip install invisible_watermark transformers accelerate safetensors
(如果找不到accelerate
,请注销并重新登录。)
登录 HF 并同意模型卡上的sd-xl 0.9 许可证。接下来,转到帐户→设置→访问令牌并生成新令牌。复制令牌并在您的 vm 上使用该特定令牌值运行以下命令。
(vm)$ huggingface-cli login --token _your_copied_token__
HuggingFace 的自述文件提供了用于在 GPU 上运行的 PyTorch 代码。要在 TPU 上运行它,第一步是将 CUDA 设备更改为 XLA 设备。这可以通过将行pipe.to("cuda")
替换为以下行来完成。
import torch_xla.core.xla_model as xm
device = xm.xla_device()
pipe.to(device)
此外,需要注意的是,第一次使用 XLA 进行推理时,编译需要很长时间。例如,HuggingFace 的 stable diffusion XL 模型推理编译时间可能需要一个小时,而实际推理可能只需要 5 秒,具体取决于批次大小。同样,GPT-2 模型可能需要大约 10-15 分钟才能编译,之后训练 epoch 时间会快得多。这是因为 XLA 构建了将要执行的计算图,然后为运行它的特定硬件优化该图。但是,一旦图编译完成,它就可以重用于后续推理,这将快得多。因此,如果您只运行一次推理,您可能不会从使用 XLA 中获益。但是,如果您多次运行推理,或者在提示列表上运行推理,那么在最初几次推理之后,您将开始看到 XLA 的优势。例如,如果您在 10 个提示的列表上运行推理,第一次推理(可能是一两次[^1])可能需要很长时间来编译,但其余的推理步骤将快得多。这是因为 XLA 将重用它为第一次推理编译的图。
如果您尝试在不进行任何其他更改的情况下运行代码,您会发现编译时间非常长 (>6 小时)。这是因为 XLA 编译器尝试一次构建一个单一的图来处理所有调度程序步骤,这与我们在上一个示例中讨论的类似。为了使代码运行得更快,我们需要使用xm.mark_step()
将图分解成更小的部分,并在后续步骤中重用它们。这发生在pipe.__call__
函数的这些行中。禁用进度条、删除回调并将xm.mark_step()
添加到 for 循环的末尾可以显著加快代码速度。更改已在此commit中提供。
此外,self.scheduler.step()
函数,它默认使用DPMSolverMultistepScheduler
调度程序,存在一些问题,这些问题在PyTorch XLA 限制中有所描述。此函数中的.nonzero()
和.item()
调用会向 CPU 发送张量求值请求,这会触发设备-主机通信。这是不理想的,因为它会减慢代码速度。在这种特定情况下,我们可以通过直接传递索引来避免这些调用。这将阻止函数向 CPU 发送请求,并提高代码的性能。更改可在此 commit 中获得。代码现在已准备好在 TPU 上运行。
性能分析和性能分析¶
为了进一步调查模型的性能,我们可以使用性能分析指南对其进行性能分析。经验法则是,应使用适合内存的最大批次大小运行性能分析脚本,以实现最佳内存使用。它还有助于重叠代码跟踪与设备执行,从而实现更优的设备利用率。性能分析的持续时间应足够长,以捕获至少一个步骤。TPU 上的模型良好性能意味着设备-主机通信已最小化,并且设备正在持续运行进程,没有空闲时间。
在inference_tpu_*.py
文件中启动服务器并运行capture_profile.py
脚本,如指南中所述,将为我们提供在设备上运行的进程信息。目前,只有一台 XLA 设备正在被分析。为了更好地理解 TPU 空闲时间(配置文件中的间隙),应将性能分析跟踪(xp.Trace()
)添加到代码中。xp.Trace()
测量在主机(CPU)上运行的 Python 代码的跟踪时间,这些代码被包装在跟踪中。在此示例中,xp.Trace()
跟踪已添加到管道和 U-net 模型中,以测量在主机(CPU)上运行特定代码段的时间。
如果配置文件中的间隙是由发生在主机上的 Python 代码跟踪引起的,那么这可能是一个瓶颈,并且目前没有进一步的直接优化方法。否则,应进一步分析代码以理解限制并进一步提高性能。请注意,您不能xp.Trace()
包装调用了xm.mark_step()
的代码部分。
为了说明这一点,我们可以查看已捕获的配置文件,这些配置文件已按照性能分析指南上传到 tensorboard。
从 Stable Diffusion 模型版本 2.1 开始
如果我们捕获一个配置文件而不插入任何跟踪,我们将看到以下内容。

v4-8 上的单个 TPU 设备有两个核心,似乎正在运行。除了中间的一个小间隙外,它们的使用没有明显间隙。如果我们向上滚动以查找占用主机机器的进程,我们将找不到任何信息。因此,我们将xp.traces
添加到管道文件以及 U-net 函数。后者可能对此特定用例没有用,但它确实演示了如何将跟踪添加到不同位置以及它们的信息如何在 TensorBoard 中显示。
如果我们添加跟踪并重新捕获配置文件(使用适合设备的`最大`批次大小(此处为 32)),我们将看到设备中的间隙是由运行在主机上的 Python 进程引起的。

我们可以使用适当的工具放大时间线,查看在那个时期运行的进程。这时发生在主机上的 Python 代码跟踪,我们目前无法进一步优化跟踪。
现在,让我们检查 XL 版本模型并执行相同的操作。我们将像处理 2.1 版本一样,将跟踪添加到管道文件中,并捕获一个配置文件。

这次,除了中间的大间隙(由pipe_watermark
跟踪引起)之外,在此循环内的推理步骤之间存在许多小间隙。
首先,让我们仔细看看由pipe_watermark
引起的大间隙。间隙前面有TransferFromDevice
,这表明主机上正在发生一些事情,并且它正在等待计算完成才能继续。查看 watermark 代码,我们可以看到张量被传输到 CPU 并转换为 numpy 数组,以便稍后用cv2
和pywt
库进行处理。由于这部分不容易优化,我们将保持原样。
现在如果我们放大循环,我们可以看到循环内的图被分解成更小的部分,因为发生了TransferFromDevice
操作。

如果我们检查 U-Net 函数和调度程序,我们可以看到 U-Net 代码不包含 PyTorch/XLA 的优化目标。但是,在scheduler.step内部有一些.item()
和.nonzero()
调用。我们可以重写该函数以避免这些调用。如果我们修复此问题并重新运行配置文件,我们不会看到太大差异。但是,由于我们减少了引入更小图的设备-主机通信,我们允许编译器更好地优化代码。函数scale_model_input存在类似问题,我们可以通过对step
函数进行更改来解决这些问题。总的来说,由于许多间隙是由 Python 级别代码跟踪和图构建引起的,因此使用当前版本的 PyTorch XLA 无法优化这些间隙,但随着将来在 PyTorch XLA 中启用 dynamo,我们可能会看到改进。
在多个 TPU 设备上运行¶
要使用多个 TPU 设备,您可以使用torch_xla.launch
函数来分发您在单个设备上运行的函数到多个设备。torch_xla.launch
函数将在多个 TPU 设备上启动进程并在需要时同步它们。这可以通过将index
参数传递给在单个设备上运行的函数来完成。例如。
import torch_xla
def my_function(index):
# function that runs on a single device
torch_xla.launch(my_function, args=(0,))
在此示例中,my_function
函数将在 v4-8 上的 4 个 TPU 设备上分发,每个设备被分配一个从 0 到 3 的索引。请注意,默认情况下,launch() 函数将在所有 TPU 设备上分发进程。如果您只想运行单个进程,请设置参数launch(..., debug_single_process=True)
。
此文件说明了如何使用 xmp.spawn 在多个 TPU 设备上运行 stable diffusion 2.1 版本。对于此版本,与上述更改类似,已对管道文件进行了更改。
在 Pod 上运行¶
一旦您有了在单主机设备上运行的代码,就不需要进一步的更改了。您可以创建 TPU pod,例如,按照这些说明进行操作。然后使用以下命令运行您的脚本。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--worker=all \
--command="python3 your_script.py"
注意
0 和 1 是 XLA 中的魔术数字,在 HLO 中被视为常量。因此,如果代码中有一个可以生成这些值的随机数生成器,代码将为每个值单独编译。这可以通过XLA_NO_SPECIAL_SCALARS=1
环境变量禁用。