update dependency on rocm/hip info command (#1900)

* add support for rocm flags

* add rocm target flag to chat args

* rm rocm libs dependency message
This commit is contained in:
Phaneesh Barwaria
2023-10-26 15:18:25 +05:30
committed by GitHub
parent 0c38c33d0a
commit 486202377a
5 changed files with 34 additions and 51 deletions

View File

@@ -725,6 +725,17 @@ p.add_argument(
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="gfx1100",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(

View File

@@ -216,8 +216,14 @@ def chat(
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
print(f"Will use target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(

View File

@@ -120,14 +120,8 @@ def check_device_drivers(device):
elif device == "cpu":
return False
elif device == "rocm":
try:
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
except Exception:
return True
# Required ROCm driver libs are already part of IREE
return False
# Unknown device. We assume drivers are installed.
return False

View File

@@ -75,7 +75,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()
return get_iree_rocm_args(extra_args=extra_args)
return []

View File

@@ -41,54 +41,26 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
def get_iree_rocm_args(extra_args=[]):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from hipinfo.
import os
import re
import subprocess
rocm_flags = ["--iree-rocm-link-bc=true"]
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
# Add the target arch flag for rocm device
flag_present = False
for flag in extra_args:
if "iree-rocm-target-chip" in flag:
flag_present = True
print(
f"found rocm target device arch from flag : {flag.split('=')[1]}"
)
if not flag_present:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
f"--iree-rocm-bc-dir={bc_path}",
]
return rocm_flags
# Some constants taken from cuda.h