mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Metal testing (#1595)
* Fixing metal_platform and device selection * fixing for metal platform * fixed for black lint formating
This commit is contained in:
committed by
GitHub
parent
788d469c5b
commit
9fcae4f808
@@ -291,7 +291,7 @@ def set_init_device_flags():
|
||||
if not args.iree_metal_target_platform:
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_metal_target_platform = triple
|
||||
args.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{args.iree_metal_target_platform}."
|
||||
|
||||
@@ -64,9 +64,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
return get_iree_metal_args(extra_args=extra_args)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
|
||||
@@ -57,15 +57,7 @@ def get_metal_target_triple(device_name):
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
elif all(x in device_name for x in ("Apple", "M2")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
|
||||
else:
|
||||
triple = None
|
||||
return triple
|
||||
return "macos"
|
||||
|
||||
|
||||
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
@@ -81,7 +73,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
triple = get_metal_target_triple(metal_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
f"Found metal device {metal_device}. Using metal target triple {triple}"
|
||||
f"Found metal device {metal_device}. Using metal target platform {triple}"
|
||||
)
|
||||
return f"-iree-metal-target-platform={triple}"
|
||||
print(
|
||||
@@ -105,12 +97,12 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
|
||||
vulkan_target_env = get_vulkan_target_env_flag(
|
||||
"-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
Reference in New Issue
Block a user