vulkan device id fix (#2028)

This commit is contained in:
Phaneesh Barwaria
2023-12-09 06:30:26 +05:30
committed by GitHub
parent 7159698496
commit bf70e80d20

View File

@@ -1616,6 +1616,10 @@ class UnshardedVicuna(VicunaBase):
self.vulkan_target_triple.split("-")[:-1]
)
differentiator = target_triple
else:
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
tt = get_vulkan_triple_flag(device_num=self.device_id)
differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1])
elif "rocm" == self.device:
from shark.iree_utils.gpu_utils import get_rocm_device_arch
@@ -2355,6 +2359,10 @@ if __name__ == "__main__":
break
id += 1
if "://" in device :
from shark.iree_utils.compile_utils import clean_device_info
_, device_id = clean_device_info(args.device)
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"