Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-10 06:17:55 -05:00)
take all ireert calls out of studio flow
@@ -65,6 +65,7 @@ _IREE_TARGET_MAP = {
 
 
 def get_available_devices():
+    return ['rocm', 'cpu']
     def get_devices_by_name(driver_name):
 
         device_list = []
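As a quick illustration of what the studio flow sees after this change, a minimal sketch follows; the ['rocm', 'cpu'] return value comes straight from the hunk above, while the import path and the print formatting are assumptions for illustration, not part of this commit:

# Hypothetical caller; the module path below is an assumption.
from apps.shark_studio.api.utils import get_available_devices

# The early return added above means no IREE runtime (ireert) driver query runs;
# the caller simply receives the two hard-coded device names.
for device in get_available_devices():
    print(f"available device: {device}")
# expected:
#   available device: rocm
#   available device: cpu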
@@ -225,65 +226,65 @@ def get_all_devices(driver_name):
     return device_list_src
 
 
-def get_device_mapping(driver, key_combination=3):
-    """This method ensures consistent device ordering when choosing
-    specific devices for execution
-    Args:
-        driver (str): execution driver (vulkan, cuda, rocm, etc)
-        key_combination (int, optional): choice for mapping value for
-            device name.
-            1 : path
-            2 : name
-            3 : (name, path)
-            Defaults to 3.
-    Returns:
-        dict: map to possible device names user can input mapped to desired
-            combination of name/path.
-    """
+# def get_device_mapping(driver, key_combination=3):
+#     """This method ensures consistent device ordering when choosing
+#     specific devices for execution
+#     Args:
+#         driver (str): execution driver (vulkan, cuda, rocm, etc)
+#         key_combination (int, optional): choice for mapping value for
+#             device name.
+#             1 : path
+#             2 : name
+#             3 : (name, path)
+#             Defaults to 3.
+#     Returns:
+#         dict: map to possible device names user can input mapped to desired
+#             combination of name/path.
+#     """
 
-    driver = iree_device_map(driver)
-    device_list = get_all_devices(driver)
-    device_map = dict()
+#     driver = iree_device_map(driver)
+#     device_list = get_all_devices(driver)
+#     device_map = dict()
 
-    def get_output_value(dev_dict):
-        if key_combination == 1:
-            return f"{driver}://{dev_dict['path']}"
-        if key_combination == 2:
-            return dev_dict["name"]
-        if key_combination == 3:
-            return dev_dict["name"], f"{driver}://{dev_dict['path']}"
+#     def get_output_value(dev_dict):
+#         if key_combination == 1:
+#             return f"{driver}://{dev_dict['path']}"
+#         if key_combination == 2:
+#             return dev_dict["name"]
+#         if key_combination == 3:
+#             return dev_dict["name"], f"{driver}://{dev_dict['path']}"
 
-    # mapping driver name to default device (driver://0)
-    device_map[f"{driver}"] = get_output_value(device_list[0])
-    for i, device in enumerate(device_list):
-        # mapping with index
-        device_map[f"{driver}://{i}"] = get_output_value(device)
-        # mapping with full path
-        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
-    return device_map
+#     # mapping driver name to default device (driver://0)
+#     device_map[f"{driver}"] = get_output_value(device_list[0])
+#     for i, device in enumerate(device_list):
+#         # mapping with index
+#         device_map[f"{driver}://{i}"] = get_output_value(device)
+#         # mapping with full path
+#         device_map[f"{driver}://{device['path']}"] = get_output_value(device)
+#     return device_map
 
 
-def get_opt_flags(model, precision="fp16"):
-    iree_flags = []
-    if len(cmd_opts.iree_vulkan_target_triple) > 0:
-        iree_flags.append(
-            f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
-        )
-    if "rocm" in cmd_opts.device:
-        from shark.iree_utils.gpu_utils import get_iree_rocm_args
+# def get_opt_flags(model, precision="fp16"):
+#     iree_flags = []
+#     if len(cmd_opts.iree_vulkan_target_triple) > 0:
+#         iree_flags.append(
+#             f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
+#         )
+#     if "rocm" in cmd_opts.device:
+#         from shark.iree_utils.gpu_utils import get_iree_rocm_args
 
-        rocm_args = get_iree_rocm_args()
-        iree_flags.extend(rocm_args)
-    if cmd_opts.iree_constant_folding == False:
-        iree_flags.append("--iree-opt-const-expr-hoisting=False")
-        iree_flags.append(
-            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
-        )
-    if cmd_opts.data_tiling == False:
-        iree_flags.append("--iree-opt-data-tiling=False")
+#         rocm_args = get_iree_rocm_args()
+#         iree_flags.extend(rocm_args)
+#     if cmd_opts.iree_constant_folding == False:
+#         iree_flags.append("--iree-opt-const-expr-hoisting=False")
+#         iree_flags.append(
+#             "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+#         )
+#     if cmd_opts.data_tiling == False:
+#         iree_flags.append("--iree-opt-data-tiling=False")
 
-    if "vae" not in model:
-        # Due to lack of support for multi-reduce, we always collapse reduction
-        # dims before dispatch formation right now.
-        iree_flags += ["--iree-flow-collapse-reduction-dims"]
-    return iree_flags
+#     if "vae" not in model:
+#         # Due to lack of support for multi-reduce, we always collapse reduction
+#         # dims before dispatch formation right now.
+#         iree_flags += ["--iree-flow-collapse-reduction-dims"]
+#     return iree_flags
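For context on the helper this commit comments out, here is a minimal sketch of the mapping its docstring describes, using key_combination=3; the device names and paths below are made up for illustration:

# Hypothetical shape of get_device_mapping("rocm", key_combination=3) for two
# devices with made-up names "gfx1100"/"gfx1030" and paths "GPU-aaaa"/"GPU-bbbb".
{
    "rocm": ("gfx1100", "rocm://GPU-aaaa"),             # bare driver name -> default (first) device
    "rocm://0": ("gfx1100", "rocm://GPU-aaaa"),          # keyed by index
    "rocm://GPU-aaaa": ("gfx1100", "rocm://GPU-aaaa"),   # keyed by full path
    "rocm://1": ("gfx1030", "rocm://GPU-bbbb"),
    "rocm://GPU-bbbb": ("gfx1030", "rocm://GPU-bbbb"),
}

Per get_output_value above, key_combination=1 would collapse the values to just the "rocm://..." path strings and key_combination=2 to just the device names.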
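Likewise, a sketch of the flag list the now-commented get_opt_flags would have assembled; the flag strings are taken verbatim from the code above, while the model name "unet" and the cmd_opts settings (constant folding and data tiling disabled, no Vulkan triple, no rocm device) are assumed for illustration:

# Flags collected by get_opt_flags("unet", precision="fp16") under the assumed cmd_opts:
[
    "--iree-opt-const-expr-hoisting=False",
    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
    "--iree-opt-data-tiling=False",
    "--iree-flow-collapse-reduction-dims",  # added because "vae" is not in the model name
]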