mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 05:47:54 -05:00
vulkan device id fix (#2028)
This commit is contained in:
committed by
GitHub
parent
7159698496
commit
bf70e80d20
@@ -1616,6 +1616,10 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.vulkan_target_triple.split("-")[:-1]
|
||||
)
|
||||
differentiator = target_triple
|
||||
else:
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
tt = get_vulkan_triple_flag(device_num=self.device_id)
|
||||
differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1])
|
||||
|
||||
elif "rocm" == self.device:
|
||||
from shark.iree_utils.gpu_utils import get_rocm_device_arch
|
||||
@@ -2355,6 +2359,10 @@ if __name__ == "__main__":
|
||||
break
|
||||
id += 1
|
||||
|
||||
if "://" in device :
|
||||
from shark.iree_utils.compile_utils import clean_device_info
|
||||
_, device_id = clean_device_info(args.device)
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
|
||||
Reference in New Issue
Block a user