• 文档 >
  • Torch 导出到 StableHLO
快捷方式

Torch 导出到 StableHLO

本文档介绍如何使用 torch export + torch xla 将模型导出为 StableHLO 格式。

完成此操作有 2 种方法

  1. 首先使用 torch.export 创建一个 ExportedProgram,其中包含 torch.fx 图中的程序。然后使用 exported_program_to_stablehlo 将其转换为包含 stablehlo MLIR 代码的对象。

  2. 首先将 pytorch 模型转换为 jax 函数,然后使用 jax 工具将其转换为 stablehlo

使用 torch.export

from torch.export import export
import torchvision
import torch
import torchax as tx
import torchax.export

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)

weights, stablehlo = tx.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like

stablehlo 对象是 jax.export.Exported 类型。欢迎探索:https://openxla.org/stablehlo/tutorials/jax-export 以获取有关如何使用由此生成的 MLIR 代码的更多详细信息。

使用 extract_jax

from torch.export import export
import torchvision
import torch
import torchax as tx
import torchax.export
import jax
import jax.numpy as jnp

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)

weights, jfunc = tx.extract_jax(resnet18)

# Below are APIs from jax

stablehlo = jax.export.export(jax.jit(jfunc))(weights, (jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype)))

print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like

在倒数第二行,我们使用 jax.ShapedDtypeStruct 来指定输入形状。您也可以在此处传递 numpy 数组。

在生成的 stablehlo 中内联某些权重

您可以通过导出调用模型的独立函数,将模型的部分或全部权重内联到 StableHLO 图中作为常量。

jax.jit 使用的约定是,被 jit 编译的 Python 函数的所有输入都将导出为参数,其他所有内容都将内联为常量。

因此,如上所示,我们导出的函数 jfuncweightsargs 作为输入,因此它们显示为参数。

如果您这样做

def jfunc_inlined(args):
  return jfunc(weights, args)

然后导出/打印出其 stablehlo

print(jax.jit(jfunc_inlined).lower((jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype, ))))

那么,您将看到内联的常量。

通过生成 stablehlo.composite 来保留 StableHLO 中的高级 PyTorch 操作

高级 PyTorch 操作(例如 F.scaled_dot_product_attention)将在 PyTorch -> StableHLO 降低过程中分解为低级操作。在下游 ML 编译器中捕获高级操作对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中匹配大量低级操作可能具有挑战性且容易出错,但我们提供了一种更稳健的方法来在 StableHLO 程序中概述高级 PyTorch 操作——通过为高级 PyTorch 操作生成 stablehlo.composite

以下示例展示了一个实际用例——捕获 scaled_product_attention

要使用 composite,我们现在需要使用以 jax 为中心的导出(即,不使用 torch.export)。我们正在努力添加对 torch.export 的支持。

import unittest
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torchax
import torchax.export
from torchax.ops import jaten
from torchax.ops import jlibrary


# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")

@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
  """Basic scaled dot product attention without all the flags/features."""
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)
  y = F.scaled_dot_product_attention(
      q,
      k,
      v,
      dropout_p=0,
      is_causal=False,
      scale=None,
  )
  return y.transpose(1, 2)

@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
  return torch.empty_like(q)

# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
  "mylib.scaled_dot_product_attention",
  _mylib_scaled_dot_product_attention,
  torch.ops.mylib.scaled_dot_product_attention,
  torch.ops.mylib.scaled_dot_product_attention.default
)

# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
  "mylib.softmax",
  jaten._aten_softmax,
  torch.ops.aten._softmax,
  static_argnums=1  # Required by JAX jit
)

class LibraryTest(unittest.TestCase):

  def setUp(self):
    torch.manual_seed(0)
    torchax.default_env().config.use_torch_native_for_cpu_tensor = False

  def test_basic_sdpa_library(self):

    class CustomOpExample(torch.nn.Module):
      def forward(self, q,k,v):
        x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
        x = x + 1
        return x

    # Export and check for composite operations
    model = CustomOpExample()
    arg = torch.rand(32, 8, 128, 64)
    args = (arg, arg, arg, )

    exported = torch.export.export(model, args=args)
    stablehlo = torchax.export.exported_program_to_stablehlo(exported)
    module_str = str(stablehlo.mlir_module())

    ## TODO Update this machinery from producing function calls to producing
    ## stablehlo.composite ops.
    self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
    self.assertIn("call @mylib.softmax", module_str)


if __name__ == '__main__':
  unittest.main()

如我们所见,要将 stablehlo 函数发出到 composite 中,首先我们需要创建一个 Python 函数来表示我们想要调用的代码区域,然后,我们将其注册,以便 pytorch 和 jlibrary 理解它是一个自定义区域。然后,导出的 StableHlo 将具有 mylib.scaled_dot_product_attentionmylib.softmax 概述的 stablehlo 函数。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源