Metal testing (#1595)

* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formating
This commit is contained in:
Ranvir Singh Virk
2023-07-08 15:22:53 -07:00
committed by GitHub
parent 788d469c5b
commit 9fcae4f808
3 changed files with 8 additions and 18 deletions

View File

@@ -291,7 +291,7 @@ def set_init_device_flags():
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
args.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."

View File

@@ -64,9 +64,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

View File

@@ -57,15 +57,7 @@ def get_metal_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
else:
triple = None
return triple
return "macos"
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
@@ -81,7 +73,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
f"Found metal device {metal_device}. Using metal target platform {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
@@ -105,12 +97,12 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag