Torch 导出到 StableHLO¶
本文档介绍如何使用 torch export + torch xla 将模型导出为 StableHLO 格式。
完成此操作有 2 种方法
首先使用 torch.export 创建一个 ExportedProgram,其中包含 torch.fx 图中的程序。然后使用
exported_program_to_stablehlo
将其转换为包含 stablehlo MLIR 代码的对象。首先将 pytorch 模型转换为 jax 函数,然后使用 jax 工具将其转换为 stablehlo
使用 torch.export
¶
from torch.export import export
import torchvision
import torch
import torchax as tx
import torchax.export
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
weights, stablehlo = tx.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
stablehlo 对象是 jax.export.Exported
类型。欢迎探索:https://openxla.org/stablehlo/tutorials/jax-export 以获取有关如何使用由此生成的 MLIR 代码的更多详细信息。
使用 extract_jax
¶
from torch.export import export
import torchvision
import torch
import torchax as tx
import torchax.export
import jax
import jax.numpy as jnp
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
weights, jfunc = tx.extract_jax(resnet18)
# Below are APIs from jax
stablehlo = jax.export.export(jax.jit(jfunc))(weights, (jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype)))
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
在倒数第二行,我们使用 jax.ShapedDtypeStruct
来指定输入形状。您也可以在此处传递 numpy 数组。
在生成的 stablehlo 中内联某些权重¶
您可以通过导出调用模型的独立函数,将模型的部分或全部权重内联到 StableHLO 图中作为常量。
jax.jit
使用的约定是,被 jit 编译的 Python 函数的所有输入都将导出为参数,其他所有内容都将内联为常量。
因此,如上所示,我们导出的函数 jfunc
以 weights
和 args
作为输入,因此它们显示为参数。
如果您这样做
def jfunc_inlined(args):
return jfunc(weights, args)
然后导出/打印出其 stablehlo
print(jax.jit(jfunc_inlined).lower((jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype, ))))
那么,您将看到内联的常量。
通过生成 stablehlo.composite
来保留 StableHLO 中的高级 PyTorch 操作¶
高级 PyTorch 操作(例如 F.scaled_dot_product_attention
)将在 PyTorch -> StableHLO 降低过程中分解为低级操作。在下游 ML 编译器中捕获高级操作对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中匹配大量低级操作可能具有挑战性且容易出错,但我们提供了一种更稳健的方法来在 StableHLO 程序中概述高级 PyTorch 操作——通过为高级 PyTorch 操作生成 stablehlo.composite。
以下示例展示了一个实际用例——捕获 scaled_product_attention
要使用 composite
,我们现在需要使用以 jax 为中心的导出(即,不使用 torch.export)。我们正在努力添加对 torch.export 的支持。
import unittest
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torchax
import torchax.export
from torchax.ops import jaten
from torchax.ops import jlibrary
# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")
@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
"""Basic scaled dot product attention without all the flags/features."""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
y = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=0,
is_causal=False,
scale=None,
)
return y.transpose(1, 2)
@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
return torch.empty_like(q)
# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
"mylib.scaled_dot_product_attention",
_mylib_scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention.default
)
# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
"mylib.softmax",
jaten._aten_softmax,
torch.ops.aten._softmax,
static_argnums=1 # Required by JAX jit
)
class LibraryTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(0)
torchax.default_env().config.use_torch_native_for_cpu_tensor = False
def test_basic_sdpa_library(self):
class CustomOpExample(torch.nn.Module):
def forward(self, q,k,v):
x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
x = x + 1
return x
# Export and check for composite operations
model = CustomOpExample()
arg = torch.rand(32, 8, 128, 64)
args = (arg, arg, arg, )
exported = torch.export.export(model, args=args)
stablehlo = torchax.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
## TODO Update this machinery from producing function calls to producing
## stablehlo.composite ops.
self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
self.assertIn("call @mylib.softmax", module_str)
if __name__ == '__main__':
unittest.main()
如我们所见,要将 stablehlo 函数发出到 composite 中,首先我们需要创建一个 Python 函数来表示我们想要调用的代码区域,然后,我们将其注册,以便 pytorch 和 jlibrary 理解它是一个自定义区域。然后,导出的 StableHlo 将具有 mylib.scaled_dot_product_attention
和 mylib.softmax
概述的 stablehlo 函数。