mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Automatic selection of Target Triple from Sytem clang (#29)
-To enable running of models and further optimization on CPU that is not x86
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -158,3 +158,6 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Shark related artefacts
|
||||
shark.venv/
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import os
|
||||
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
||||
@@ -23,8 +24,23 @@ IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
|
||||
|
||||
def get_iree_compiled_module(module, device: str):
|
||||
"""TODO: Documentation"""
|
||||
args = ["--iree-llvm-target-cpu-features=host"]
|
||||
if(device == "cpu"):
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = subprocess.run(find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True).stdout.decode('utf-8').split()
|
||||
if os_name == "Darwin":
|
||||
find_kernel_version_cmd = "uname -r"
|
||||
kernel_version = subprocess.run(find_kernel_version_cmd, shell=True, stdout=subprocess.PIPE, check=True).stdout.decode('utf-8')
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
args.append(f"-iree-llvm-target-triple={target_triple}")
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module), target_backends=[IREE_DEVICE_MAP[device]])
|
||||
str(module), target_backends=[IREE_DEVICE_MAP[device]], extra_args=args)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
|
||||
Reference in New Issue
Block a user