• 文档 >
  • Torch 导出到 StableHLO
快捷方式

Torch 导出到 StableHLO

本文档介绍了如何使用 torch export + torch xla 将模型导出为 StableHLO 格式。

有两种方法可以实现此目的

  1. 首先,进行 torch.export 创建一个 ExportedProgram,其中包含 torch.fx 图中的程序。然后使用 exported_program_to_stablehlo 将其转换为包含 stablehlo MLIR 代码的对象。

  2. 首先将 pytorch 模型转换为 jax 函数,然后使用 jax 工具将其转换为 stablehlo

使用 torch.export

from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)

weights, stablehlo = tx.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like

stablehlo 对象是 jax.export.Exported 类型。欢迎探索: https://openxla.org/stablehlo/tutorials/jax-export 以了解有关如何使用由此生成的 MLIR 代码的更多详细信息。

使用 extract_jax

from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export
import jax
import jax.numpy as jnp

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)

weights, jfunc = tx.extract_jax(resnet18)

# Below are APIs from jax

stablehlo = jax.export.export(jax.jit(jfunc))(weights, (jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype)))

print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like

倒数第二行我们使用了 jax.ShapedDtypeStruct 来指定输入形状。您也可以在此处传递 numpy 数组。

通过生成 stablehlo.composite 在 StableHLO 中保留高级 PyTorch 操作

高级 PyTorch op(例如 F.scaled_dot_product_attention)将在 PyTorch -> StableHLO 降低过程中分解为低级 op。在 ML 编译器中捕获高级 op 对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中匹配大量低级 op 可能具有挑战性且容易出错,但我们提供了一种更强大的方法来在 StableHLO 程序中概述高级 PyTorch op - 通过为高级 PyTorch op 生成 stablehlo.composite

以下示例展示了一个实际用例 - 捕获 scaled_product_attention

要使用 composite,我们现在需要使用以 jax 为中心的导出(即,不使用 torch.export)。我们正在努力添加对 torch.export 的支持。

import unittest
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torch_xla2
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary


# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")

@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
  """Basic scaled dot product attention without all the flags/features."""
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)
  y = F.scaled_dot_product_attention(
      q,
      k,
      v,
      dropout_p=0,
      is_causal=False,
      scale=None,
  )
  return y.transpose(1, 2)

@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
  return torch.empty_like(q)

# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
  "mylib.scaled_dot_product_attention",
  _mylib_scaled_dot_product_attention,
  torch.ops.mylib.scaled_dot_product_attention,
  torch.ops.mylib.scaled_dot_product_attention.default
)

# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
  "mylib.softmax",
  jaten._aten_softmax,
  torch.ops.aten._softmax,
  static_argnums=1  # Required by JAX jit
)

class LibraryTest(unittest.TestCase):

  def setUp(self):
    torch.manual_seed(0)
    torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False

  def test_basic_sdpa_library(self):

    class CustomOpExample(torch.nn.Module):
      def forward(self, q,k,v):
        x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
        x = x + 1
        return x

    # Export and check for composite operations
    model = CustomOpExample()
    arg = torch.rand(32, 8, 128, 64)
    args = (arg, arg, arg, )

    exported = torch.export.export(model, args=args)
    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
    module_str = str(stablehlo.mlir_module())

    ## TODO Update this machinery from producing function calls to producing
    ## stablehlo.composite ops.
    self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
    self.assertIn("call @mylib.softmax", module_str)


if __name__ == '__main__':
  unittest.main()

正如我们所见,要将 stablehlo 函数发出到 composite,首先我们创建一个代表我们想要调用的代码区域的 python 函数,然后,我们注册它,以便 pytorch 和 jlibrary 了解它是一个自定义区域。然后,发出的 Stablehlo 将具有 mylib.scaled_dot_product_attentionmylib.softmax 概述的 stablehlo 函数。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源