Add support for torch to stablehlo and tosa in shark_importer (#1360)

This commit is contained in:
Chi_Liu
2023-04-27 08:09:45 -07:00
committed by GitHub
parent 1db906a373
commit aa8ada9da9
2 changed files with 16 additions and 3 deletions

View File

@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:

View File

@@ -19,6 +19,12 @@ import tempfile
from shark.parser import shark_args
import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"mhlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
mlir_type: str = "linalg",
):
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
mlir_module = torch_mlir.compile(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
output_type=mlir_type_mapping_dict[mlir_type],
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()