mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
# Copyright 2023 The Nod Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# All the iree_metal related functionalities go here.
|
|
|
|
import functools
|
|
|
|
from shark.iree_utils._common import run_cmd
|
|
import iree.runtime as ireert
|
|
from sys import platform
|
|
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
|
|
|
|
|
@functools.cache
def get_metal_device_name(device_num=0):
    """Return the name of the Metal device at index ``device_num``.

    Runs ``iree-run-module --dump_devices`` via ``run_cmd`` and keeps the
    blank-line-separated entries of the dump that mention ``--device=metal``.
    Results are memoised per ``device_num`` by ``functools.cache``, so the
    command runs (and the device listing is printed) at most once per index.

    Args:
        device_num (int): index into the list of discovered Metal devices.
            Normal Python indexing applies, so negative values count from
            the end.

    Returns:
        str: the selected device name.

    Raises:
        ValueError: if no Metal device appears in the dump.
        IndexError: if ``device_num`` is out of range, or a matching entry
            has fewer than three "\\n#"-separated segments.
    """
    # NOTE(review): assumes run_cmd returns a sequence whose first item is
    # the command's stdout text -- confirm against shark.iree_utils._common.
    dump_text = run_cmd("iree-run-module --dump_devices")[0]
    candidates = []
    for entry in dump_text.split("\n\n"):
        if "--device=metal" not in entry:
            continue
        # The third "\n#"-separated segment is taken as the device name
        # (presumably the dump's comment-header layout; verify on changes).
        candidates.append(entry.split("\n#")[2])

    if not candidates:
        raise ValueError("No device name found in device dump!")

    if len(candidates) > 1:
        print("Following devices found:")
        for idx, name in enumerate(candidates):
            print(f"{idx}. {name}")
        print(f"Choosing device: {candidates[device_num]}")
    return candidates[device_num]
|
|
|
|
|
|
def get_os_name():
    """Map ``sys.platform`` to the short OS name used by SHARK.

    Returns:
        str: ``"linux"``, ``"macos"`` or ``"windows"``. Unrecognised
        platforms fall back to ``"linux"`` after printing a notice.
    """
    if platform.startswith("linux"):
        return "linux"

    # Exact-match platforms other than the "linux*" family handled above.
    known_platforms = {"darwin": "macos", "win32": "windows"}
    os_name = known_platforms.get(platform)
    if os_name is None:
        print("Cannot detect OS type, defaulting to linux.")
        return "linux"
    return os_name
|
|
|
|
|
|
def get_metal_target_triple(device_name):
    """Return the IREE Metal target platform for the given Metal device.

    Every device currently maps to ``"macos"``; ``device_name`` is accepted
    for interface compatibility but is not inspected.

    Args:
        device_name (str): name of the Metal device to compile for
            (currently ignored).

    Returns:
        str: the target platform string, currently always ``"macos"``.
        This function never returns ``None`` at present.
    """
    # NOTE(review): there is no per-device mapping yet. get_metal_triple_flag
    # treats a None return as "no match found", so keep returning a string
    # unless that fallback path is intended.
    return "macos"
|
|
|
|
|
|
def get_metal_triple_flag(device_name="", device_num=0, extra_args=None):
    """Build the ``-iree-metal-target-platform=`` compiler flag.

    Args:
        device_name (str): Metal device name. When empty, ``[]`` or ``None``
            the device is auto-detected with ``get_metal_device_name``.
        device_num (int): device index used for auto-detection.
        extra_args (Iterable[str] | None): flags already supplied by the
            caller. If any of them sets the Metal target platform, no new
            flag is produced. ``None`` (the default) means no extra flags.

    Returns:
        str | None: the flag string, or ``None`` when the caller already
        provided one or no target platform could be determined.
    """
    marker = "-iree-metal-target-platform="
    # `extra_args or ()` replaces the former shared mutable default `[]`.
    for flag in extra_args or ():
        if marker in flag:
            # Report everything after the option name, even when the value
            # itself contains "=" (split('=')[1] used to truncate it).
            print(f"Using target triple {flag.partition(marker)[2]}")
            return None

    if device_name is None or device_name == "" or device_name == []:
        metal_device = get_metal_device_name(device_num=device_num)
    else:
        metal_device = device_name

    triple = get_metal_target_triple(metal_device)
    if triple is not None:
        print(
            f"Found metal device {metal_device}. Using metal target platform {triple}"
        )
        return f"-iree-metal-target-platform={triple}"

    print(
        """Optimized kernel for your target device is not added yet.
        Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
        or pull up an issue."""
    )
    print(f"Target : {metal_device}")
    return None
|
|
|
|
|
|
def get_iree_metal_args(device_num=0, extra_args=None):
    """Return the IREE compiler flags to use for the Metal backend.

    No Metal-specific flags are added yet; the result is a fresh list
    holding ``extra_args`` in their original order.

    Args:
        device_num (int): currently unused; accepted so existing callers
            that pass it keep working.
        extra_args (Iterable[str] | None): caller-supplied flags to forward.
            ``None`` (the default) means no extra flags.

    Returns:
        list[str]: a new list, so callers may mutate it without affecting
        ``extra_args``.
    """
    # Add any metal specific compilation flags here.
    res_metal_flag = []
    # `None` default replaces the former shared mutable default `[]`.
    if extra_args is not None:
        res_metal_flag.extend(extra_args)
    return res_metal_flag
|
|
|
|
|
|
def set_iree_metal_runtime_flags(flags):
    """Hand each runtime flag, in order, to the IREE runtime flag parser.

    Args:
        flags (Iterable[str]): runtime flag strings; each one is passed to
            ``iree.runtime.flags.parse_flags`` in its own call.

    Returns:
        None
    """
    for runtime_flag in flags:
        ireert.flags.parse_flags(runtime_flag)
|