Thread an optional Vulkan device index through the Vicuna pipeline.

Summary of what this patch does:
  * UnshardedVicuna accepts `device_id` and forwards it to SharkInference as
    `device_idx`, both when loading a cached vmfb and when compiling.
  * get_vmfb_from_path / get_vmfb_from_config accept `device_id` (default None).
  * The vicuna CLI derives `device_id` by matching --iree_vulkan_target_triple
    against the triples of the devices listed by get_all_vulkan_devices().
  * The web UI parses the index out of the "vulkan://N" device string.
  * get_all_vulkan_devices() is factored out of get_vulkan_device_name().

Review fixes folded into the added lines:
  * `assert device_id` rejected a match at index 0 (0 is falsy); the checks
    now read `device_id is not None`.
  * Device lookups use enumerate() with fresh names instead of a manual
    counter called `id` and a loop variable called `device`; in chat() the
    old loop variable overwrote the function's own `device` value.
  * get_available_devices() no longer adds an unused `run_cmd` import or an
    f-string without placeholders.
  * Trailing commas added to the new multi-line call/import arguments.

NOTE(review): this patch arrived with newlines and indentation collapsed.
Line breaks were recovered from the hunk line counts; indentation and the
position of blank context lines are inferred from Python syntax and must be
confirmed against the tree before applying. The placement of the
"Will use target triple" print inside the vulkan branch is likewise inferred.
The post-image hashes on the `index` lines predate the review fixes.

diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index a1f57e1a..e0675f77 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -1229,6 +1229,7 @@ class UnshardedVicuna(VicunaBase):
         download_vmfb=False,
         cache_vicunas=False,
         extra_args_cmd=[],
+        device_id=None,
         debug=False,
     ) -> None:
         super().__init__(
@@ -1247,6 +1248,7 @@ class UnshardedVicuna(VicunaBase):
         print(f"[DEBUG] hf model name: {self.hf_model_path}")
         self.max_sequence_length = 256
         self.device = device
+        self.device_id = device_id
         self.precision = precision
         self.download_vmfb = download_vmfb
         self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1409,7 +1411,7 @@ class UnshardedVicuna(VicunaBase):
                 single_file=True,
             )
         self.shark_model = get_vmfb_from_path(
-            self.vicuna_vmfb_path, self.device, "tm_tensor"
+            self.vicuna_vmfb_path, self.device, "tm_tensor", self.device_id
         )
         if self.shark_model is not None:
             print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
@@ -1657,6 +1659,7 @@ class UnshardedVicuna(VicunaBase):
             mlir_module=combined_module,
             device=self.device,
             mlir_dialect="tm_tensor",
+            device_idx=self.device_id,
         )
         path = shark_module.save_module(
             self.vicuna_vmfb_path.parent.absolute(),
@@ -1808,11 +1811,37 @@ if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
 
     _extra_args = []
-    # vulkan target triple
-    if args.iree_vulkan_target_triple != "":
+    device_id = None
+    # Process vulkan target triple.
+    # TODO: This feature should just be in a common utils for other LLMs and in general
+    # any model run via SHARK for Vulkan backend.
+    vulkan_target_triple = args.iree_vulkan_target_triple
+    if vulkan_target_triple != "":
         _extra_args.append(
             f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
         )
+        # Step 1. Fetch the device ID.
+        from shark.iree_utils.vulkan_utils import (
+            get_all_vulkan_devices,
+            get_vulkan_target_triple,
+        )
+
+        for idx, vulkan_device in enumerate(get_all_vulkan_devices()):
+            target_triple = get_vulkan_target_triple(vulkan_device)
+            if target_triple == vulkan_target_triple:
+                device_id = idx
+                break
+
+        # Index 0 is a valid device, so compare against None, not truthiness.
+        assert (
+            device_id is not None
+        ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
+        # Step 2. Add a few flags targeting specific hardware.
+        if "rdna" in vulkan_target_triple:
+            flags_to_add = [
+                "--iree-spirv-index-bits=64",
+            ]
+            _extra_args = _extra_args + flags_to_add
 
     vic = None
     if not args.sharded:
@@ -1838,6 +1867,7 @@ if __name__ == "__main__":
             download_vmfb=args.download_vmfb,
             cache_vicunas=args.cache_vicunas,
             extra_args_cmd=_extra_args,
+            device_id=device_id,
         )
     else:
         if args.config is not None:
diff --git a/apps/language_models/utils.py b/apps/language_models/utils.py
index c9892ed5..20bebdf9 100644
--- a/apps/language_models/utils.py
+++ b/apps/language_models/utils.py
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
 
 # expects a Path / str as arg
 # returns None if path not found or SharkInference module
-def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
+def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
     if not isinstance(vmfb_path, Path):
         vmfb_path = Path(vmfb_path)
 
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
     print("Loading vmfb from: ", vmfb_path)
     print("Device from get_vmfb_from_path - ", device)
     shark_module = SharkInference(
-        None, device=device, mlir_dialect=mlir_dialect
+        None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
     )
     shark_module.load_module(vmfb_path)
     print("Successfully loaded vmfb")
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
 
 
 def get_vmfb_from_config(
-    shark_container, model, precision, device, vmfb_path, padding=None
+    shark_container,
+    model,
+    precision,
+    device,
+    vmfb_path,
+    padding=None,
+    device_id=None,
 ):
     vmfb_url = (
         f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
         vmfb_url = vmfb_url + f"_{padding}"
     vmfb_url = vmfb_url + ".vmfb"
     download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
-    return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
+    return get_vmfb_from_path(
+        vmfb_path, device, "tm_tensor", device_id=device_id
+    )
diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py
index d514f1d8..641c6579 100644
--- a/apps/stable_diffusion/src/utils/utils.py
+++ b/apps/stable_diffusion/src/utils/utils.py
@@ -470,7 +470,16 @@ def get_available_devices():
     set_iree_runtime_flags()
 
     available_devices = []
-    vulkan_devices = get_devices_by_name("vulkan")
+    from shark.iree_utils.vulkan_utils import (
+        get_all_vulkan_devices,
+    )
+
+    vulkan_devices = [
+        f"{info.split('=')[1].strip()} => vulkan://{idx}"
+        for idx, info in enumerate(get_all_vulkan_devices())
+    ]
+    if vulkan_devices:
+        print("vulkan devices are available.")
     available_devices.extend(vulkan_devices)
     metal_devices = get_devices_by_name("metal")
     available_devices.extend(metal_devices)
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index b09a0a8a..ead65d94 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -143,6 +143,7 @@ def chat(
     global model_vmfb_key
     global vicuna_model
 
+    device_id = None
     model_name, model_path = list(map(str.strip, model.split("=>")))
     if "cuda" in device:
         device = "cuda"
@@ -151,6 +152,7 @@ def chat(
     elif "task" in device:
         device = "cpu-task"
     elif "vulkan" in device:
+        device_id = int(device.split("://")[1])
         device = "vulkan"
     elif "rocm" in device:
         device = "rocm"
@@ -169,10 +171,43 @@ def chat(
         # get iree flags that need to be overridden, from commandline args
         _extra_args = []
         # vulkan target triple
-        if args.iree_vulkan_target_triple != "":
+        vulkan_target_triple = args.iree_vulkan_target_triple
+        from shark.iree_utils.vulkan_utils import (
+            get_all_vulkan_devices,
+            get_vulkan_target_triple,
+        )
+
+        if device == "vulkan":
+            vulkaninfo_list = get_all_vulkan_devices()
+            if vulkan_target_triple == "":
+                # We already have the device_id extracted via WebUI, so we directly use
+                # that to find the target triple.
+                vulkan_target_triple = get_vulkan_target_triple(
+                    vulkaninfo_list[device_id]
+                )
             _extra_args.append(
-                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+                f"-iree-vulkan-target-triple={vulkan_target_triple}"
             )
+            if "rdna" in vulkan_target_triple:
+                flags_to_add = [
+                    "--iree-spirv-index-bits=64",
+                ]
+                _extra_args = _extra_args + flags_to_add
+
+            if device_id is None:
+                # Fresh loop names keep chat()'s `device` value intact.
+                for idx, vulkan_device in enumerate(vulkaninfo_list):
+                    target_triple = get_vulkan_target_triple(vulkan_device)
+                    if target_triple == vulkan_target_triple:
+                        device_id = idx
+                        break
+
+                # Index 0 is a valid device, so compare against None.
+                assert (
+                    device_id is not None
+                ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
+
+            print(f"Will use target triple : {vulkan_target_triple}")
 
         if model_name == "vicuna4":
             vicuna_model = ShardedVicuna(
@@ -196,6 +231,7 @@ def chat(
                 download_vmfb=download_vmfb,
                 load_mlir_from_shark_tank=True,
                 extra_args_cmd=_extra_args,
+                device_id=device_id,
             )
 
     prompt = create_prompt(model_name, history)
@@ -254,6 +290,7 @@ def llm_chat_api(InputData: dict):
         UnshardedVicuna,
     )
 
+    device_id = None
     if vicuna_model == 0:
         if "cuda" in device:
             device = "cuda"
@@ -262,6 +299,7 @@ def llm_chat_api(InputData: dict):
         elif "task" in device:
             device = "cpu-task"
         elif "vulkan" in device:
+            device_id = int(device.split("://")[1])
             device = "vulkan"
         else:
             print("unrecognized device")
@@ -274,6 +312,7 @@ def llm_chat_api(InputData: dict):
             max_num_tokens=max_toks,
             download_vmfb=True,
             load_mlir_from_shark_tank=True,
+            device_id=device_id,
         )
 
     # TODO: add role dict for different models
diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py
index e65f2ad9..fd494a29 100644
--- a/shark/iree_utils/vulkan_utils.py
+++ b/shark/iree_utils/vulkan_utils.py
@@ -24,10 +24,17 @@ from shark.parser import shark_args
 
 
 @functools.cache
-def get_vulkan_device_name(device_num=0):
+def get_all_vulkan_devices():
+    """Return the stripped `deviceName` lines printed by `vulkaninfo`."""
     vulkaninfo_dump, _ = run_cmd("vulkaninfo")
     vulkaninfo_dump = vulkaninfo_dump.split(linesep)
     vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
+    return vulkaninfo_list
+
+
+@functools.cache
+def get_vulkan_device_name(device_num=0):
+    vulkaninfo_list = get_all_vulkan_devices()
     if len(vulkaninfo_list) == 0:
         raise ValueError("No device name found in VulkanInfo!")
     if len(vulkaninfo_list) > 1: