mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add run_on_refbackend and gpu configs. (#52)
`run_on_refbackend` is added to run linalg_on_tensors_backend compiled code on torch-mlir's refbackend. Also, added gpu configs and flags.
This commit is contained in:
@@ -52,7 +52,7 @@ def check_device_drivers(device):
|
||||
def get_iree_compiled_module(module, device: str):
|
||||
"""Given an mlir module returns the compiled .vmfb"""
|
||||
args = ["--iree-llvm-target-cpu-features=host"]
|
||||
if (device == "cpu"):
|
||||
if device == "cpu":
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = subprocess.run(
|
||||
find_triple_cmd, shell=True, stdout=subprocess.PIPE,
|
||||
@@ -71,6 +71,12 @@ def get_iree_compiled_module(module, device: str):
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
args.append(f"-iree-llvm-target-triple={target_triple}")
|
||||
|
||||
if device == "gpu":
|
||||
args += ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module), target_backends=[IREE_DEVICE_MAP[device]], extra_args=args)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file, run_on_refbackend
|
||||
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
|
||||
import os
|
||||
from shark.functorch_utils import AOTModule
|
||||
@@ -23,7 +23,7 @@ import time
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
"""TODO: Write the description"""
|
||||
"""Base class for Shark Inference and Shark Runner."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -55,6 +55,10 @@ class SharkRunner:
|
||||
return get_results(self.iree_compilation_module, input,
|
||||
self.iree_config)
|
||||
|
||||
# Refbackend is used only for debugging purpose. It can be quite slow.
|
||||
def run_on_refbackend(self, input):
|
||||
return run_on_refbackend(self.torch_mlir_module, input)
|
||||
|
||||
|
||||
class SharkInference:
|
||||
"""TODO: Write the description"""
|
||||
@@ -107,6 +111,9 @@ class SharkInference:
|
||||
input_list = [x.detach().numpy() for x in inputs]
|
||||
return self.shark_runner.forward(input_list)
|
||||
|
||||
def run_on_refbackend(self, inputs):
|
||||
self.shark_runner.run_on_refbackend(inputs)
|
||||
|
||||
|
||||
class SharkTrainer:
|
||||
"""TODO: Write the description"""
|
||||
|
||||
@@ -27,6 +27,8 @@ from torch_mlir_e2e_test.torchscript.serialization import (
|
||||
extract_serializable_annotations, apply_serializable_annotations,
|
||||
SerializableTest)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
from torch_mlir.ir import StringAttr
|
||||
@@ -66,6 +68,12 @@ def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
|
||||
annotations_list.append(tuple(temp_list))
|
||||
return annotations_list
|
||||
|
||||
def run_on_refbackend(torch_module, inputs):
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(torch_module)
|
||||
jit_module = backend.load(compiled)
|
||||
np_inputs = [x.numpy() for x in inputs]
|
||||
return jit_module.forward(np_inputs[0])
|
||||
|
||||
def shark_jit_trace(module, input: tuple, dynamic: bool,
|
||||
tracing_required: bool):
|
||||
|
||||
Reference in New Issue
Block a user