mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 23:08:19 -05:00
add support for choosing vulkan device (#439)
This commit is contained in:
committed by
GitHub
parent
29a317dbb6
commit
749a2c2dec
@@ -13,7 +13,12 @@ if args.import_mlir:
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
extended_name = "{}_{}".format(model_name, args.device)
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print("Loading flatbuffer from {}".format(vmfb_path))
|
||||
|
||||
@@ -37,7 +37,64 @@ def run_cmd(cmd):
|
||||
sys.exit("Exiting program due to error running:", cmd)
|
||||
|
||||
|
||||
IREE_DEVICE_MAP = {
|
||||
def iree_device_map(device):
|
||||
|
||||
from iree.runtime import get_driver, get_device
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list = []
|
||||
for device_dict in device_list_src:
|
||||
device_list.append(f"{driver_name}://{device_dict['path']}")
|
||||
device_list.sort()
|
||||
return device_list
|
||||
|
||||
# only supported for vulkan as of now
|
||||
if "vulkan://" in device:
|
||||
device_list = get_all_devices("vulkan")
|
||||
_, d_index = device.split("://")
|
||||
matched_index = None
|
||||
match_with_index = False
|
||||
if 0 <= len(d_index) <= 2:
|
||||
try:
|
||||
d_index = int(d_index)
|
||||
except:
|
||||
print(
|
||||
f"{d_index} is not valid index or uri. Will choose device 0"
|
||||
)
|
||||
d_index = 0
|
||||
match_with_index = True
|
||||
|
||||
if len(device_list) > 1:
|
||||
print("List of available vulkan devices:")
|
||||
for i, d in enumerate(device_list):
|
||||
print(f"vulkan://{i} => {d}")
|
||||
if (match_with_index and d_index == i) or (
|
||||
not match_with_index and d == device
|
||||
):
|
||||
matched_index = i
|
||||
print(
|
||||
f"Choosing device vulkan://{matched_index}\nTo choose another device please specify device index or uri accordingly."
|
||||
)
|
||||
return get_device(device_list[matched_index])
|
||||
elif len(device_list) == 1:
|
||||
print(f"Found one vulkan device: {device_list[0]}. Using this.")
|
||||
return get_device(device_list[0])
|
||||
else:
|
||||
print(
|
||||
f"No device found! returning device corresponding to driver name: vulkan"
|
||||
)
|
||||
return _IREE_DEVICE_MAP["vulkan"]
|
||||
else:
|
||||
return _IREE_DEVICE_MAP[device]
|
||||
|
||||
|
||||
def get_supported_device_list():
|
||||
return list(_IREE_DEVICE_MAP.keys())
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -46,7 +103,14 @@ IREE_DEVICE_MAP = {
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
IREE_TARGET_MAP = {
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device]
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -58,6 +122,9 @@ IREE_TARGET_MAP = {
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
@@ -69,7 +69,7 @@ def build_benchmark_args(
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--entry_function={fn_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--function_input={mlir_input}")
|
||||
@@ -96,7 +96,7 @@ def build_benchmark_args_non_tensor_input(
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
if function_name:
|
||||
benchmark_cl.append(f"--entry_function={function_name}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
for input in inputs:
|
||||
benchmark_cl.append(f"--function_input={input}")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
import numpy as np
|
||||
import os
|
||||
@@ -224,7 +224,7 @@ def compile_module_to_flatbuffer(
|
||||
# Currently for MHLO/TOSA.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
input_type=input_type,
|
||||
)
|
||||
@@ -232,7 +232,7 @@ def compile_module_to_flatbuffer(
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_TARGET_MAP[device]],
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
|
||||
@@ -241,7 +241,12 @@ def compile_module_to_flatbuffer(
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, func_name):
|
||||
# Returns the compiled module and the configs.
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
device = iree_device_map(device)
|
||||
if type(device) == ireert.HalDevice:
|
||||
config = ireert.Config(device=device)
|
||||
else:
|
||||
driver_name = device.split("://")[0] if "://" in device else device
|
||||
config = ireert.Config(driver_name=driver_name)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
)
|
||||
@@ -291,7 +296,8 @@ def export_iree_module_to_vmfb(
|
||||
module, device, mlir_dialect, func_name, model_config_path, extra_args
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
device_name = device.split("://")[0]
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
|
||||
@@ -87,7 +87,6 @@ class SharkImporter:
|
||||
|
||||
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
from iree.compiler import tflite as tflitec
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
|
||||
self.mlir_model = tflitec.compile_file(
|
||||
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
IREE_DEVICE_MAP,
|
||||
get_supported_device_list,
|
||||
)
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
from parameterized import parameterized
|
||||
@@ -59,7 +59,7 @@ def get_valid_test_params():
|
||||
"""
|
||||
device_list = [
|
||||
device
|
||||
for device in IREE_DEVICE_MAP.keys()
|
||||
for device in get_supported_device_list()
|
||||
if not check_device_drivers(device)
|
||||
]
|
||||
dynamic_list = (True, False)
|
||||
|
||||
Reference in New Issue
Block a user