• 文档 >
  • 自定义硬件插件
快捷方式

自定义硬件插件

PyTorch/XLA 通过 OpenXLA 的 PJRT C API 支持自定义硬件。PyTorch/XLA 团队直接支持 Cloud TPU (libtpu) 和 GPU (OpenXLA) 的插件。JAX 和 TF 也可能使用相同的插件。

实现 PJRT 插件

PJRT C API 插件可以是闭源的,也可以是开源的。它们包含两个部分:

  1. 公开 PJRT C API 实现的二进制文件。这部分可以与 JAX 和 TensorFlow 共享。

  2. 包含上述二进制文件的 Python 包,以及我们 DevicePlugin Python 接口的实现,该接口负责额外的设置。

PJRT C API 实现

简而言之,您必须实现一个 PjRtClient,其中包含您设备的 XLA 编译器和运行时。PJRT C++ 接口在 PJRT_Api 中通过 C 进行了镜像。最直接的选择是用 C++ 实现您的插件,然后 将其包装 为 C API 实现。这个过程在 OpenXLA 的文档 中有详细解释。

有关具体示例,请参阅 示例实现

PyTorch/XLA 插件包

此时,您应该有一个功能正常的 PJRT 插件二进制文件,您可以使用占位符 LIBRARY 设备类型对其进行测试。例如:

$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

要自动为用户注册您的设备类型,以及处理多进程等额外设置,您可以实现 DevicePlugin Python API。PyTorch/XLA 插件包包含两个关键组件:

  1. 一个 DevicePlugin 实现,该实现(至少)提供您的插件二进制文件的路径。例如:

class CpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
  1. 一个 torch_xla.plugins 入口点,用于标识您的 DevicePlugin。例如,要在 pyproject.toml 中注册 EXAMPLE 设备类型:

<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"

安装您的包后,您就可以直接使用您的 EXAMPLE 设备了:

$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

DevicePlugin 提供了用于多进程初始化和客户端选项的附加扩展点。该 API 目前处于实验状态,但预计将在未来的版本中稳定。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源