Automatic selection of Target Triple from Sytem clang (#29)

-To enable running of models and further optimization on CPU that is not x86
This commit is contained in:
Stanley Winata
2022-04-29 15:49:37 -07:00
committed by GitHub
parent bd212634c1
commit b8602d0b64
2 changed files with 20 additions and 1 deletions

3
.gitignore vendored
View File

@@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Shark related artefacts
shark.venv/

View File

@@ -14,6 +14,7 @@
import iree.runtime as ireert
import iree.compiler as ireec
import subprocess
import numpy as np
import os
from shark.torch_mlir_utils import get_module_name_for_asm_dump
@@ -23,8 +24,23 @@ IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
def get_iree_compiled_module(module, device: str):
"""TODO: Documentation"""
args = ["--iree-llvm-target-cpu-features=host"]
if(device == "cpu"):
find_triple_cmd = "uname -s -m"
os_name, proc_name = subprocess.run(find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True).stdout.decode('utf-8').split()
if os_name == "Darwin":
find_kernel_version_cmd = "uname -r"
kernel_version = subprocess.run(find_kernel_version_cmd, shell=True, stdout=subprocess.PIPE, check=True).stdout.decode('utf-8')
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
elif os_name == "Linux":
target_triple = f"{proc_name}-linux-gnu"
else:
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
args.append(f"-iree-llvm-target-triple={target_triple}")
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]])
str(module), target_backends=[IREE_DEVICE_MAP[device]], extra_args=args)
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
config = ireert.Config(IREE_DEVICE_MAP[device])
ctx = ireert.SystemContext(config=config)