# Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 05:24:00 -05:00)
# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
from torch_mlir.ir import StringAttr
|
|
import torch_mlir
|
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
|
import tempfile
|
|
from amdshark.parser import amdshark_args
|
|
import io
|
|
|
|
# Maps a user-facing dialect name (as accepted by `get_torch_mlir_module`'s
# `mlir_type` argument) to the corresponding torch-mlir output type.
mlir_type_mapping_dict = {
    "linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
    "stablehlo": torch_mlir.OutputType.STABLEHLO,
    "tosa": torch_mlir.OutputType.TOSA,
}
|
|
|
|
|
|
def get_module_name_for_asm_dump(module):
    """Gets a name suitable for an assembly dump.

    The name is not guaranteed to be unique.

    Args:
        module: An MLIR module; its `torch.debug_module_name` operation
            attribute, when present, carries the original module's name.

    Returns:
        The value of the `torch.debug_module_name` attribute, or the
        placeholder "UnnammedModule" when the attribute is absent.
    """
    # NOTE: "UnnammedModule" (sic) is intentionally kept with its historical
    # misspelling — downstream consumers may match on this exact string.
    if "torch.debug_module_name" not in module.operation.attributes:
        return "UnnammedModule"
    return StringAttr(
        module.operation.attributes["torch.debug_module_name"]
    ).value
|
|
|
|
|
|
def run_on_refbackend(torch_module, inputs):
    """Compile and run a linalg-on-tensors module on the reference backend.

    Args:
        torch_module: The lowered MLIR module to execute.
        inputs: Sequence of torch tensors; each is converted via `.numpy()`.

    Returns:
        The result of invoking the compiled module's `forward`.
    """
    ref_backend = refbackend.RefBackendLinalgOnTensorsBackend()
    jitted = ref_backend.load(ref_backend.compile(torch_module))
    numpy_inputs = [tensor.numpy() for tensor in inputs]
    # NOTE(review): only the first input is forwarded — presumably
    # single-input modules are assumed here; confirm with callers.
    return jitted.forward(numpy_inputs[0])
|
|
|
|
|
|
# Creates dynamic dims for all dims.
# TODO: Pass user specified dynamic dims.
def create_dynamic_placeholders(inputs):
    """Build a TensorPlaceholder per input with every axis marked dynamic.

    Args:
        inputs: Iterable of tensors exposing a `.shape` attribute.

    Returns:
        Tuple of `torch_mlir.TensorPlaceholder`, one per input, whose
        axes are all dynamic.
    """
    return tuple(
        torch_mlir.TensorPlaceholder.like(
            tensor, dynamic_axes=list(range(len(tensor.shape)))
        )
        for tensor in inputs
    )
|
|
|
|
|
|
def get_torch_mlir_module(
    module,
    input: tuple,
    dynamic: bool,
    jit_trace: bool,
    return_str: bool = False,
    mlir_type: str = "linalg",
):
    """Lower a TorchScript module through torch-mlir.

    Args:
        module: The torch module to compile.
        input: Tuple of example inputs used to shape/trace the module.
        dynamic: When True, every input dimension is made dynamic via
            `create_dynamic_placeholders`.
        jit_trace: When True, compile with torch.jit tracing; combined
            with `dynamic`, traced shapes are ignored.
        return_str: When True, return the textual MLIR assembly instead
            of serialized bytecode.
        mlir_type: Target abstraction — a key of `mlir_type_mapping_dict`
            ("linalg", "stablehlo", or "tosa").

    Returns:
        The MLIR assembly string if `return_str` is True, otherwise the
        module's bytecode as `bytes`.
    """
    # Traced shapes are only ignored when the inputs were rewritten as
    # fully dynamic placeholders (matches the original nested-if logic).
    ignore_traced_shapes = dynamic and jit_trace
    if dynamic:
        input = create_dynamic_placeholders(input)

    # Global side effect kept from the original: torch-mlir temporaries
    # land in the current working directory.
    tempfile.tempdir = "."

    compiled_module = torch_mlir.compile(
        module,
        input,
        output_type=mlir_type_mapping_dict[mlir_type],
        use_tracing=jit_trace,
        ignore_traced_shapes=ignore_traced_shapes,
    )

    if return_str:
        return compiled_module.operation.get_asm()

    stream = io.BytesIO()
    compiled_module.operation.write_bytecode(stream)
    return stream.getvalue()
|