Torch 导出到 StableHLO¶
本文档介绍了如何使用 torch export + torch xla 将模型导出为 StableHLO 格式。
有两种方法可以实现此目的
首先,进行 torch.export 创建一个 ExportedProgram,其中包含 torch.fx 图中的程序。然后使用
exported_program_to_stablehlo
将其转换为包含 stablehlo MLIR 代码的对象。首先将 pytorch 模型转换为 jax 函数,然后使用 jax 工具将其转换为 stablehlo
使用 torch.export
¶
from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
weights, stablehlo = tx.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
stablehlo 对象是 jax.export.Exported
类型。欢迎探索: https://openxla.org/stablehlo/tutorials/jax-export 以了解有关如何使用由此生成的 MLIR 代码的更多详细信息。
使用 extract_jax
¶
from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export
import jax
import jax.numpy as jnp
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
weights, jfunc = tx.extract_jax(resnet18)
# Below are APIs from jax
stablehlo = jax.export.export(jax.jit(jfunc))(weights, (jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype)))
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
倒数第二行我们使用了 jax.ShapedDtypeStruct
来指定输入形状。您也可以在此处传递 numpy 数组。
通过生成 stablehlo.composite
在 StableHLO 中保留高级 PyTorch 操作¶
高级 PyTorch op(例如 F.scaled_dot_product_attention
)将在 PyTorch -> StableHLO 降低过程中分解为低级 op。在 ML 编译器中捕获高级 op 对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中匹配大量低级 op 可能具有挑战性且容易出错,但我们提供了一种更强大的方法来在 StableHLO 程序中概述高级 PyTorch op - 通过为高级 PyTorch op 生成 stablehlo.composite。
以下示例展示了一个实际用例 - 捕获 scaled_product_attention
要使用 composite
,我们现在需要使用以 jax 为中心的导出(即,不使用 torch.export)。我们正在努力添加对 torch.export 的支持。
import unittest
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torch_xla2
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary
# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")
@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
"""Basic scaled dot product attention without all the flags/features."""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
y = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=0,
is_causal=False,
scale=None,
)
return y.transpose(1, 2)
@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
return torch.empty_like(q)
# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
"mylib.scaled_dot_product_attention",
_mylib_scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention.default
)
# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
"mylib.softmax",
jaten._aten_softmax,
torch.ops.aten._softmax,
static_argnums=1 # Required by JAX jit
)
class LibraryTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(0)
torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
def test_basic_sdpa_library(self):
class CustomOpExample(torch.nn.Module):
def forward(self, q,k,v):
x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
x = x + 1
return x
# Export and check for composite operations
model = CustomOpExample()
arg = torch.rand(32, 8, 128, 64)
args = (arg, arg, arg, )
exported = torch.export.export(model, args=args)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
## TODO Update this machinery from producing function calls to producing
## stablehlo.composite ops.
self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
self.assertIn("call @mylib.softmax", module_str)
if __name__ == '__main__':
unittest.main()
正如我们所见,要将 stablehlo 函数发出到 composite,首先我们创建一个代表我们想要调用的代码区域的 python 函数,然后,我们注册它,以便 pytorch 和 jlibrary 了解它是一个自定义区域。然后,发出的 Stablehlo 将具有 mylib.scaled_dot_product_attention
和 mylib.softmax
概述的 stablehlo 函数。