故障排除¶
请注意,本节中的信息可能会在未来版本的PyTorch/XLA软件中被移除,因为其中许多信息是特定于给定内部实现的,可能会发生更改。
健全性检查¶
在进行任何深入调试之前,我们想对已安装的 PyTorch/XLA 进行一次健全性检查。
检查 PyTorch/XLA 版本¶
PyTorch 和 PyTorch/XLA 版本应匹配。有关可用版本,请参阅我们的自述文件。
vm:~$ python
>>> import torch
>>> import torch_xla
>>> print(torch.__version__)
2.1.0+cu121
>>> print(torch_xla.__version__)
2.1.0
执行简单计算¶
vm:~$ export PJRT_DEVICE=TPU
vm:~$ python3
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> print(t1 + t2)
tensor(300, device='xla:0')
使用假数据运行 Resnet¶
对于 nightly 版本
vm:~$ git clone https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data
对于发布版本 x.y
,您需要使用 rx.y
分支。例如,如果您安装了 2.1 版本,则应执行
vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data
如果我们能运行 resnet,那么我们可以得出结论 torch_xla 已正确安装。
性能调试¶
为了诊断性能问题,我们可以使用PyTorch/XLA提供的执行指标和计数器。当模型速度慢时,**首要**检查的是生成指标报告。
指标报告对于诊断问题非常有帮助。如果您遇到问题,请在提交给我们的错误报告中包含它。
PyTorch/XLA 调试工具¶
您可以通过设置 PT_XLA_DEBUG_LEVEL=2
来启用 PyTorch/XLA 调试工具,它提供了几个有用的调试功能。您也可以将调试级别降低到 1
来进行执行分析。
执行自动指标分析¶
调试工具将分析指标报告并提供摘要。一些示例如下:
pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps
编译与执行分析¶
调试工具将分析您的模型的每一次编译和执行。一些示例如下:
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step in parallel loader at step end
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Compilation Analysis: Number of Graph Inputs: 35
Compilation Analysis: Number of Graph Outputs: 107
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Compilation Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Compilation Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Compilation Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Compilation Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Compilation Analysis: <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.548000 GB
Post Compilation Analysis: Graph output size: 7.922460 GB
Post Compilation Analysis: Aliased Input size: 1.547871 GB
Post Compilation Analysis: Intermediate tensor size: 12.124478 GB
Post Compilation Analysis: Compiled program size: 0.028210 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: mark_step in parallel loader at step end
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Execution Analysis: Number of Graph Inputs: 35
Execution Analysis: Number of Graph Outputs: 107
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Execution Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Execution Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Execution Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Execution Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Execution Analysis: <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
编译/执行的常见原因包括:1. 用户手动调用 mark_step
。2. Parallel loader 为每个 x(可配置)批次调用 mark_step
。3. 退出 profiler StepTrace 区域。
Dynamo 决定编译/执行图。5. 用户尝试在
mark_step
之前访问(通常是由于日志记录)张量的值。
由 1-4 引起的执行是预期的,我们希望通过减少访问张量值频率或手动添加 mark_step
来避免 5。
用户应该期望看到前几个步骤中的 Compilation Cause
+ Executation Cause
对。在模型稳定后,用户应该只看到 Execution Cause
(您可以通过 PT_XLA_DEBUG_LEVEL=1
禁用执行分析)。为了高效地使用 PyTorch/XLA,我们期望相同的模型代码在每个步骤中运行,并且每个图只编译一次。如果您不断看到 Compilation Cause
,您应该尝试按照本节的说明转储 IR/HLO,并比较每个步骤的图,以了解差异的来源。
下一节将解释如何获取和理解更详细的指标报告。
获取指标报告¶
在您的程序中添加以下行以生成报告:
import torch_xla.debug.metrics as met
# For short report that only contains a few key metrics.
print(met.short_metrics_report())
# For full report that includes all metrics.
print(met.metrics_report())
理解指标报告¶
报告包括:- 我们发出XLA编译的次数以及发出所需的时间。- 执行的次数以及执行所需的时间- 我们创建/销毁的设备数据句柄的数量等。
这些信息以样本的百分位数形式报告。例如:
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
我们还提供计数器,它们是命名的整数变量,用于跟踪内部软件状态。例如:
Counter: CachedSyncTensors
Value: 395
在此报告中,任何以 aten::
开头的计数器都表示 XLA 设备和 CPU 之间的上下文切换,这可能是模型代码中潜在的性能优化区域。
计数器有助于了解哪些操作被路由回 PyTorch 的 CPU 引擎。它们已通过其 C++ 命名空间完全限定。
Counter: aten::nonzero
Value: 33
如果您看到 aten::
操作(除了 nonzero
和 _local_scalar_dense
)之外的操作,这通常意味着 PyTorch/XLA 中缺少低级实现。请随时在 GitHub issues 上为此打开一个功能请求。
PyTorch/XLA + Dynamo 调试工具¶
您可以通过设置 XLA_DYNAMO_DEBUG=1
来启用 PyTorch/XLA + Dynamo 调试工具。
已知的性能注意事项¶
PyTorch/XLA 在语义上与常规 PyTorch 相同,XLA 张量与 CPU 和 GPU 张量共享完整的张量接口。但是,XLA/硬件的限制以及惰性求值模型表明某些模式可能会导致性能不佳。
如果您的模型显示性能不佳,请牢记以下注意事项:
XLA/TPU 在过多的重新编译时性能会下降。
XLA 编译成本很高。PyTorch/XLA 会在每次遇到新形状时自动重新编译图。通常模型会在几步内稳定下来,而您可以在其余训练中看到巨大的速度提升。
为了避免重新编译,不仅形状必须恒定,所有主机中的 XLA 设备上的计算也必须是恒定的。
可能的来源:
直接或间接使用
nonzero
会引入动态形状;例如,掩码索引base[index]
,其中index
是一个掩码张量。迭代次数在步骤之间不同的循环会导致不同的执行图,从而需要重新编译。
解决方案:
张量形状在迭代之间应相同,或应使用少量形状变化。
如果可能,将张量填充到固定大小。
某些操作没有直接转换为 XLA 的方法。
对于这些操作,PyTorch/XLA 会自动将其传输到 CPU 内存,在 CPU 上求值,然后将结果传输回 XLA 设备。在训练步骤中执行过多此类操作可能导致显着的速度下降。
可能的来源:
item()
操作明确要求对结果进行求值。除非必要,否则不要使用它。
解决方案:
对于大多数操作,我们可以将其低级化到 XLA 来修复。请查看 指标报告部分 以找出缺失的操作,并在 GitHub 上打开一个功能请求。
即使 PyTorch 张量被认为是标量,也要避免使用 tensor.item()。将其保留为张量,并在其上使用张量操作。
在适用时使用
torch.where
替换控制流。例如,带有item()
的控制流在 clip_grad_norm 中使用,这存在问题并且会影响性能,因此我们修补了clip_grad_norm_
,通过调用torch.where
来替代,这给我们带来了显著的性能提升。... else: device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters: param_norm = p.grad.data.norm(norm_type) ** norm_type total_norm.add_(param_norm) total_norm = (total_norm ** (1. / norm_type)) clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) for p in parameters: p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
``torch_xla.distributed.data_parallel`` 中的迭代器可能会丢失输入迭代器中最后几个批次的数据。
这是为了确保所有 XLA 设备执行相同的工作量。
解决方案:
当数据集很小,且轮次太少时,这可能导致一个空的 epoch。因此,在这种情况下最好使用小的批次大小。
XLA 张量怪癖¶
XLA 张量内部是不透明的。 XLA 张量始终显示为连续的且没有存储。网络不应尝试检查 XLA 张量的跨度。
XLA 张量在保存之前应移至 CPU。 直接保存 XLA 张量会导致它们被重新加载到保存它们的设备上。如果在加载时设备不可用,则加载将失败。在保存之前将 XLA 张量移至 CPU 可让您决定加载张量时要放在哪些设备上。如果要在没有 XLA 设备的机器上加载张量,这是必需的。但是,在将 XLA 张量保存到 CPU 之前,应小心移动它们,因为跨设备类型移动张量不会保留视图关系。相反,视图应在加载张量后根据需要进行重建。
使用 Python 的 copy.copy 复制 XLA 张量会返回深拷贝,而不是浅拷贝。 使用 XLA 张量的视图来获得其浅拷贝。
处理共享权重。 模块可以通过将一个模块的参数设置为另一个模块来共享权重。这种模块权重的“绑定”应在模块移至 XLA 设备**之后**完成。否则,将在 XLA 设备上创建共享张量的两个独立副本。
更多调试工具¶
我们不期望用户使用本节中的工具来调试他们的模型。但当您提交错误报告时,我们可能会要求提供这些工具,因为它们提供了指标报告没有的额外信息。
调试张量操作¶
以下工具对于收集已低级化操作执行信息很有用。
print(torch_xla._XLAC._get_xla_tensors_text([res]))
,其中res
是结果张量,将打印 IR。print(torch_xla._XLAC._get_xla_tensors_hlo([res]))
,其中res
是结果张量,将打印生成的 XLA HLO。
请注意,这些函数必须在 mark_step()
之前调用,否则张量将已被物化。
环境变量¶
还有许多环境变量控制着PyTorch/XLA软件栈的行为。
设置这些变量会导致不同程度的性能下降,因此它们应该只在调试时启用。
XLA_IR_DEBUG
:启用在创建 IR 节点时捕获Python堆栈跟踪,从而可以了解哪个PyTorch操作负责生成 IR。XLA_HLO_DEBUG
:启用在XLA_IR_DEBUG处于活动状态时捕获的Python堆栈帧,以传播到XLAHLO元数据。XLA_SAVE_TENSORS_FILE
:将用于在执行过程中转储 IR 图的文件路径。请注意,如果该选项保持启用状态且PyTorch程序运行时间较长,文件可能会变得非常大。图会被追加到文件中,因此要从运行到运行的干净记录,应明确删除该文件。XLA_SAVE_TENSORS_FMT
:存储在XLA_SAVE_TENSORS_FILE文件中的图的格式。可以是text
(默认),dot
(Graphviz格式)或hlo
。XLA_FLAGS=--xla_dump_to
:如果设置为=/tmp/dir_name
,XLA 编译器将在每次编译时转储未优化和优化的 HLO。XLA_METRICS_FILE
:如果设置,则为将内部指标保存在每个步骤中的本地文件路径。如果文件已存在,则指标将被追加到该文件中。XLA_SAVE_HLO_FILE
:如果设置,则为在发生编译/执行错误时,将错误 HLO 图保存到的本地文件路径。XLA_SYNC_WAIT
:在移动到下一步之前,强制 XLA 张量同步操作等待其完成。XLA_USE_EAGER_DEBUG_MODE
:强制 XLA 张量进行即时执行,即逐个编译和执行 torch 操作。这有助于绕过长时间的编译时间,但总体步骤时间会慢得多,内存使用量也会更高,因为所有编译器优化都将被跳过。TF_CPP_LOG_THREAD_ID
:如果设置为 1,TF 日志将显示线程 ID,有助于调试多线程进程。TF_CPP_VMODULE
:用于 TF VLOG 的环境变量,形式为TF_CPP_VMODULE=name=value,...
。请注意,对于 VLOG,您必须将TF_CPP_MIN_LOG_LEVEL=0
设置为。TF_CPP_MIN_LOG_LEVEL
:要打印消息的级别。TF_CPP_MIN_LOG_LEVEL=0
将启用 INFO 日志记录,TF_CPP_MIN_LOG_LEVEL=1
警告,依此类推。我们的 PyTorch/XLATF_VLOG
默认使用tensorflow::INFO
级别,因此要查看 VLOG,请将TF_CPP_MIN_LOG_LEVEL=0
设置为。XLA_DUMP_HLO_GRAPH
:如果设置为=1
,在发生编译或执行错误时,将作为xla_util.cc
抛出的运行时错误的一部分,转储有问题的 HLO 图。
常见的调试环境变量组合¶
以 IR 格式记录图执行
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="text" XLA_SAVE_TENSORS_FILE="/tmp/save1.ir"
以 HLO 格式记录图执行
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"
显示运行时和图编译/执行的调试 VLOG
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"
重现 PyTorch/XLA CI/CD 单元测试失败。¶
您可能会看到 PR 的某些测试失败,例如:
要执行此测试,请从基础 repo 目录运行以下命令:
PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
直接在命令行中运行此命令不起作用。您需要将环境变量 TORCH_TEST_DEVICES
设置为您本地的 pytorch/xla/test/pytorch_test_base.py
。例如:
TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
应该可以工作。