Files
SHARK-Studio/amdshark/iree_utils/_common.py
pdhirajkumarprasad f6dd02fa67 Fix for migration (#2183)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-24 15:32:02 +05:30

165 lines
4.8 KiB
Python

# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## Common utilities to be shared by iree utilities.
import functools
import os
import sys
import subprocess
def run_cmd(cmd, debug=False, raise_err=False):
    """Run a shell command and return its decoded output.

    Inputs:
        cmd : cli command string. It is executed with ``shell=True``, so it
            must come from a trusted source.
        debug : if True, prints the command before running it.
        raise_err : if True, raise exception to caller when the command exits
            with a non-zero status; otherwise the process is terminated.

    Returns:
        A ``(stdout, stderr)`` tuple of decoded strings.

    Raises:
        Exception: the command failed and ``raise_err`` is True. The message
            carries the exit code, the command and its stderr; the original
            ``subprocess.CalledProcessError`` is chained as ``__cause__``.
        SystemExit: the command failed and ``raise_err`` is False.
    """
    if debug:
        print("IREE run command: \n\n")
        print(cmd)
        print("\n\n")
    try:
        result = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # Decode leniently: a decoding failure here must not hide the real
        # error reported by the command.
        failed_stdout = (e.stdout or b"").decode(errors="replace")
        failed_stderr = (e.stderr or b"").decode(errors="replace")
        if raise_err:
            raise Exception(
                f"Command failed with exit code {e.returncode}: {cmd}\n"
                f"{failed_stderr}"
            ) from e
        # Show both streams; the diagnostic usually lives in stderr.
        if failed_stdout:
            print(failed_stdout)
        if failed_stderr:
            print(failed_stderr, file=sys.stderr)
        sys.exit(f"Exiting program due to error running {cmd}")
    return result.stdout.decode(), result.stderr.decode()
def iree_device_map(device):
    """Translate a device string into the corresponding IREE driver string.

    ``device`` is either a bare name (``"cpu"``) or ``"<name>://<path>"``.
    The name is looked up in ``_IREE_DEVICE_MAP``; names not in the table are
    passed through unchanged.

    Returns:
        * the driver name alone when ``device`` has no ``://`` part;
        * ``"rocm"`` when any ``://``-separated piece equals ``"rocm"``;
        * ``"<driver>://<path>"`` otherwise, where ``<path>`` is the piece
          right after the first ``://`` (later pieces are ignored).
    """
    # maxsplit=2 yields at most three pieces; only the first two are used.
    pieces = device.split("://", 2)
    scheme = pieces[0]
    driver = _IREE_DEVICE_MAP.get(scheme, scheme)

    if len(pieces) == 1:
        return driver
    if "rocm" in pieces:
        return "rocm"
    return f"{driver}://{pieces[1]}"
def get_supported_device_list():
    """Return the device names known to ``_IREE_DEVICE_MAP`` as a new list."""
    # Iterating a dict yields its keys in insertion order.
    return list(_IREE_DEVICE_MAP)
# Lookup table from user-facing device name to the driver name returned by
# iree_device_map(). Its keys, in this order, are what
# get_supported_device_list() reports.
_IREE_DEVICE_MAP = {
    # "cpu" and "cpu-task" resolve to the same driver.
    "cpu": "local-task",
    "cpu-task": "local-task",
    "cpu-sync": "local-sync",
    "cuda": "cuda",
    "vulkan": "vulkan",
    "metal": "metal",
    "rocm": "rocm",
    "hip": "hip",
    "intel-gpu": "level_zero",
}
def iree_target_map(device):
    """Translate a device string into its target name.

    Everything from the first ``://`` onward is discarded. The remaining name
    is looked up in ``_IREE_TARGET_MAP`` and returned unchanged when it has no
    entry there.
    """
    # partition() leaves the string untouched when there is no "://".
    name = device.partition("://")[0]
    return _IREE_TARGET_MAP.get(name, name)
# Lookup table from user-facing device name to the target name returned by
# iree_target_map(). NOTE(review): presumably these are IREE compiler target
# backend names -- confirm against the compile step that consumes them.
_IREE_TARGET_MAP = {
    # All CPU variants share one target.
    "cpu": "llvm-cpu",
    "cpu-task": "llvm-cpu",
    "cpu-sync": "llvm-cpu",
    "cuda": "cuda",
    "vulkan": "vulkan-spirv",
    "metal": "metal",
    # "rocm" and "hip" share one target.
    "rocm": "rocm",
    "hip": "rocm",
    "intel-gpu": "opencl-spirv",
}
# Reports whether the driver for a device could NOT be obtained from IREE.
@functools.cache
def check_device_drivers(device):
    """
    Probe the IREE runtime for the driver backing ``device``.

    Any ``://...`` suffix is dropped before the lookup. Results are memoized
    per argument string by ``functools.cache``.

    Returns:
        True  => the driver lookup failed (an error message is printed)
        False => the driver was obtained
    """
    # Imported lazily, inside the function, as in the original layout.
    from iree.runtime import get_driver

    device = device.partition("://")[0]
    device_mapped = iree_device_map(device)

    try:
        get_driver(device_mapped)
    except ValueError as ve:
        print(
            f"[ERR] device `{device}` not registered with IREE. "
            "Ensure IREE is configured for use with this device.\n"
            f"Full Error: \n {repr(ve)}"
        )
        lookup_failed = True
    except RuntimeError as re:
        print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}")
        lookup_failed = True
    else:
        # No error from the runtime: treat the drivers as present.
        lookup_failed = False
    return lookup_failed
# Installation info for the missing device drivers.
def device_driver_info(device):
    """Return a help message for a device whose drivers are missing.

    Args:
        device: device name, optionally followed by ``://<id>``
            (e.g. ``"vulkan"`` or ``"vulkan://0"``).

    Returns:
        A troubleshooting and installation message when the device name is
        one of ``cuda``, ``vulkan``, ``metal`` or ``rocm``; otherwise
        ``"<device> is not supported."`` with ``device`` exactly as given.
    """
    device_driver_err_map = {
        "cuda": {
            "debug": "Try `nvidia-smi` on system to check.",
            "solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
        },
        "vulkan": {
            "debug": "Try `vulkaninfo` on system to check.",
            "solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
        },
        "metal": {
            "debug": "Check if Bare metal is supported and enabled on your system.",
            "solution": ".",
        },
        "rocm": {
            # The suggested info tool is `hipinfo` on Windows, `rocminfo` elsewhere.
            "debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
            "solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
        },
    }
    # check_device_drivers() and iree_target_map() both drop a "://<id>"
    # suffix before their lookups; do the same here so that "vulkan://0"
    # gets the vulkan hints instead of being reported as unsupported.
    driver_family = device.split("://")[0]
    hints = device_driver_err_map.get(driver_family)
    if hints is None:
        return f"{device} is not supported."
    return (
        f"Required drivers for {driver_family} not found. {hints['debug']} "
        f"Please install the required drivers{hints['solution']} "
        "For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
        " and/or file a bug at https://github.com/nod-ai/AMD-SHARK-Studio/issues"
    )