自动加载树外扩展#
作者: Yuanhao Ji
扩展自动加载机制使 PyTorch 能够在无需显式导入语句的情况下自动加载树外(out-of-tree)后端扩展。此功能对用户非常有益,因为它提升了使用体验,使用户能够遵循熟悉的 PyTorch 设备编程模型,而无需显式加载或导入特定设备的扩展。此外,它还促进了现有 PyTorch 应用程序在树外设备上的轻松适配,且无需更改任何代码。有关详细信息,请参阅 [RFC] Autoload Device Extension。
如何在 PyTorch 中使用树外扩展自动加载
查看 Intel Gaudi HPU 和华为 Ascend NPU 的示例
PyTorch v2.5 或更高版本
注意
此功能默认启用,可以通过使用 export TORCH_DEVICE_BACKEND_AUTOLOAD=0 来禁用。如果您收到类似“Failed to load the backend extension”(加载后端扩展失败)的错误,此错误与 PyTorch 无关,您应该禁用此功能并向树外扩展的维护者寻求帮助。
如何将此机制应用于树外扩展?#
例如,假设您有一个名为 foo 的后端和一个对应的包 torch_foo。请确保您的包与 PyTorch 2.5 或更高版本兼容,并在其 __init__.py 文件中包含以下代码片段
def _autoload():
print("Check things are working with `torch.foo.is_available()`.")
然后,您唯一需要做的就是在您的 Python 包中定义一个入口点(entry point)
setup(
name="torch_foo",
version="1.0",
entry_points={
"torch.backends": [
"torch_foo = torch_foo:_autoload",
],
}
)
现在,您只需添加 import torch 语句即可导入 torch_foo 模块,而无需添加 import torch_foo
>>> import torch
Check things are working with `torch.foo.is_available()`.
>>> torch.foo.is_available()
True
在某些情况下,您可能会遇到循环导入的问题。下面的示例演示了如何解决这些问题。
示例#
在本例中,我们将使用 Intel Gaudi HPU 和华为 Ascend NPU 来确定如何使用自动加载功能将您的树外扩展与 PyTorch 集成。
habana_frameworks.torch 是一个 Python 包,它允许用户通过使用 PyTorch HPU 设备键在 Intel Gaudi 上运行 PyTorch 程序。
habana_frameworks.torch 是 habana_frameworks 的子模块,我们在 habana_frameworks/setup.py 中为 __autoload() 添加了一个入口点
setup(
name="habana_frameworks",
version="2.5",
+ entry_points={
+ 'torch.backends': [
+ "device_backend = habana_frameworks:__autoload",
+ ],
+ }
)
在 habana_frameworks/init.py 中,我们使用一个全局变量来跟踪我们的模块是否已被加载
import os
is_loaded = False # A member variable of habana_frameworks module to track if our module has been imported
def __autoload():
# This is an entrypoint for pytorch autoload mechanism
# If the following condition is true, that means our backend has already been loaded, either explicitly
# or by the autoload mechanism and importing it again should be skipped to avoid circular imports
global is_loaded
if is_loaded:
return
import habana_frameworks.torch
在 habana_frameworks/torch/init.py 中,我们通过更新全局变量的状态来防止循环导入
import os
# This is to prevent torch autoload mechanism from causing circular imports
import habana_frameworks
habana_frameworks.is_loaded = True
torch_npu 使用户能够在华为 Ascend NPU 上运行 PyTorch 程序,它利用 PrivateUse1 设备键并将设备名称作为 npu 暴露给最终用户。
我们在 torch_npu/setup.py 中定义了一个入口点
setup(
name="torch_npu",
version="2.5",
+ entry_points={
+ 'torch.backends': [
+ 'torch_npu = torch_npu:_autoload',
+ ],
+ }
)
与 habana_frameworks 不同,torch_npu 使用环境变量 TORCH_DEVICE_BACKEND_AUTOLOAD 来控制自动加载过程。例如,我们将其设置为 0 以禁用自动加载,从而防止循环导入
# Disable autoloading before running 'import torch'
os.environ['TORCH_DEVICE_BACKEND_AUTOLOAD'] = '0'
import torch
工作原理#
自动加载基于 Python 的 入口点 (Entrypoints) 机制实现。我们在 torch/__init__.py 中发现并加载由树外扩展定义的所有特定入口点。
如上所示,安装 torch_foo 后,您的 Python 模块可以在加载您定义的入口点时被导入,然后您可以在调用它时执行一些必要的工作。
请参阅此 Pull Request 中的实现:[RFC] Add support for device extension autoloading。
结论#
在本教程中,我们了解了 PyTorch 中的树外扩展自动加载机制,该机制可自动加载后端扩展,无需添加额外的导入语句。我们还学习了如何通过定义入口点将此机制应用于树外扩展,以及如何防止循环导入。此外,我们还回顾了一个关于如何将自动加载机制与 Intel Gaudi HPU 和华为 Ascend NPU 结合使用的示例。