From bf70e80d2024dfdab2bd4472c40e91f42e26be38 Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Sat, 9 Dec 2023 06:30:26 +0530 Subject: [PATCH] vulkan device id fix (#2028) --- apps/language_models/scripts/vicuna.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 544c3fde..548dcb83 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1616,6 +1616,10 @@ class UnshardedVicuna(VicunaBase): self.vulkan_target_triple.split("-")[:-1] ) differentiator = target_triple + else: + from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag + tt = get_vulkan_triple_flag(device_num=self.device_id) + differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1]) elif "rocm" == self.device: from shark.iree_utils.gpu_utils import get_rocm_device_arch @@ -2355,6 +2359,10 @@ if __name__ == "__main__": break id += 1 + if "://" in device : + from shark.iree_utils.compile_utils import clean_device_info + _, device_id = clean_device_info(args.device) + assert ( device_id ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"