add support for choosing vulkan device (#439)

This commit is contained in:
Phaneesh Barwaria
2022-11-13 03:30:41 +05:30
committed by GitHub
parent 29a317dbb6
commit 749a2c2dec
6 changed files with 91 additions and 14 deletions

View File

@@ -13,7 +13,12 @@ if args.import_mlir:
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
extended_name = "{}_{}".format(model_name, args.device)
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print("Loading flatbuffer from {}".format(vmfb_path))

View File

@@ -37,7 +37,64 @@ def run_cmd(cmd):
sys.exit("Exiting program due to error running:", cmd)
IREE_DEVICE_MAP = {
def iree_device_map(device):
from iree.runtime import get_driver, get_device
def get_all_devices(driver_name):
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list = []
for device_dict in device_list_src:
device_list.append(f"{driver_name}://{device_dict['path']}")
device_list.sort()
return device_list
# only supported for vulkan as of now
if "vulkan://" in device:
device_list = get_all_devices("vulkan")
_, d_index = device.split("://")
matched_index = None
match_with_index = False
if 0 <= len(d_index) <= 2:
try:
d_index = int(d_index)
except:
print(
f"{d_index} is not valid index or uri. Will choose device 0"
)
d_index = 0
match_with_index = True
if len(device_list) > 1:
print("List of available vulkan devices:")
for i, d in enumerate(device_list):
print(f"vulkan://{i} => {d}")
if (match_with_index and d_index == i) or (
not match_with_index and d == device
):
matched_index = i
print(
f"Choosing device vulkan://{matched_index}\nTo choose another device please specify device index or uri accordingly."
)
return get_device(device_list[matched_index])
elif len(device_list) == 1:
print(f"Found one vulkan device: {device_list[0]}. Using this.")
return get_device(device_list[0])
else:
print(
f"No device found! returning device corresponding to driver name: vulkan"
)
return _IREE_DEVICE_MAP["vulkan"]
else:
return _IREE_DEVICE_MAP[device]
def get_supported_device_list():
return list(_IREE_DEVICE_MAP.keys())
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -46,7 +103,14 @@ IREE_DEVICE_MAP = {
"intel-gpu": "level_zero",
}
IREE_TARGET_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -58,6 +122,9 @@ IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
if "://" in device:
device = device.split("://")[0]
if device == "cuda":
try:
subprocess.check_output("nvidia-smi")

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
from shark.iree_utils._common import run_cmd, iree_device_map
import numpy as np
import os
import re
@@ -69,7 +69,7 @@ def build_benchmark_args(
# TODO: Replace name of train with actual train fn name.
fn_name = "train"
benchmark_cl.append(f"--entry_function={fn_name}")
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
for mlir_input in mlir_input_types:
benchmark_cl.append(f"--function_input={mlir_input}")
@@ -96,7 +96,7 @@ def build_benchmark_args_non_tensor_input(
# TODO: The function named can be passed as one of the args.
if function_name:
benchmark_cl.append(f"--entry_function={function_name}")
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
for input in inputs:
benchmark_cl.append(f"--function_input={input}")
time_extractor = "| awk 'END{{print $2 $3}}'"

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.benchmark_utils import *
import numpy as np
import os
@@ -224,7 +224,7 @@ def compile_module_to_flatbuffer(
# Currently for MHLO/TOSA.
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[IREE_TARGET_MAP[device]],
target_backends=[iree_target_map(device)],
extra_args=args,
input_type=input_type,
)
@@ -232,7 +232,7 @@ def compile_module_to_flatbuffer(
# Currently for Torch.
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[IREE_TARGET_MAP[device]],
target_backends=[iree_target_map(device)],
extra_args=args,
)
@@ -241,7 +241,12 @@ def compile_module_to_flatbuffer(
def get_iree_module(flatbuffer_blob, device, func_name):
# Returns the compiled module and the configs.
config = ireert.Config(IREE_DEVICE_MAP[device])
device = iree_device_map(device)
if type(device) == ireert.HalDevice:
config = ireert.Config(device=device)
else:
driver_name = device.split("://")[0] if "://" in device else device
config = ireert.Config(driver_name=driver_name)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
)
@@ -291,7 +296,8 @@ def export_iree_module_to_vmfb(
module, device, mlir_dialect, func_name, model_config_path, extra_args
)
if module_name is None:
module_name = f"{mlir_dialect}_{func_name}_{device}"
device_name = device.split("://")[0]
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:

View File

@@ -87,7 +87,6 @@ class SharkImporter:
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
from iree.compiler import tflite as tflitec
from shark.iree_utils._common import IREE_TARGET_MAP
self.mlir_model = tflitec.compile_file(
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter

View File

@@ -1,7 +1,7 @@
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
IREE_DEVICE_MAP,
get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from parameterized import parameterized
@@ -59,7 +59,7 @@ def get_valid_test_params():
"""
device_list = [
device
for device in IREE_DEVICE_MAP.keys()
for device in get_supported_device_list()
if not check_device_drivers(device)
]
dynamic_list = (True, False)