mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
enable non default rocm device selection for webui
This commit is contained in:
committed by
Phaneesh Barwaria
parent
91df5f0613
commit
392bade0bf
@@ -1259,6 +1259,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
max_num_tokens=512,
|
||||
min_num_tokens=0,
|
||||
device="cpu",
|
||||
device_id=None,
|
||||
vulkan_target_triple="",
|
||||
precision="int8",
|
||||
vicuna_mlir_path=None,
|
||||
@@ -1269,7 +1270,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
download_vmfb=False,
|
||||
cache_vicunas=False,
|
||||
extra_args_cmd=[],
|
||||
device_id=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@@ -1288,9 +1288,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.min_num_tokens = min_num_tokens
|
||||
self.device = device
|
||||
self.vulkan_target_triple = vulkan_target_triple
|
||||
self.device_id = device_id
|
||||
self.precision = precision
|
||||
self.download_vmfb = download_vmfb
|
||||
self.vicuna_vmfb_path = vicuna_vmfb_path
|
||||
@@ -1299,12 +1297,22 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.low_device_memory = low_device_memory
|
||||
self.weight_group_size = weight_group_size
|
||||
self.debug = debug
|
||||
# Sanity check for device, device_id pair
|
||||
if "://" in device:
|
||||
if device_id is not None:
|
||||
print("[ERR] can't have both full device path and a device id.\n"
|
||||
f"Device : {device} | device_id : {device_id}\n"
|
||||
"proceeding with given Device ignoring device_id")
|
||||
self.device, self.device_id = device.split("://")
|
||||
else:
|
||||
self.device, self.device_id = device, device_id
|
||||
if self.vicuna_mlir_path == None:
|
||||
self.vicuna_mlir_path = self.get_model_path()
|
||||
if self.vicuna_vmfb_path == None:
|
||||
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.cache_vicunas = cache_vicunas
|
||||
|
||||
self.compile()
|
||||
|
||||
def get_model_path(self, suffix="mlir"):
|
||||
@@ -1752,9 +1760,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
print(self.device)
|
||||
if "rocm" in self.device:
|
||||
self.device = "rocm"
|
||||
print(f"Compiling for device : {self.device}"
|
||||
f"{'://' + str(self.device_id) if self.device_id is not None else ''}")
|
||||
shark_module = SharkInference(
|
||||
mlir_module=combined_module,
|
||||
device=self.device,
|
||||
|
||||
@@ -132,6 +132,27 @@ def get_default_config():
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by LLM pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
device_id = int(device_id) # using device index in webui
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = None
|
||||
|
||||
return device, device_id
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
@@ -151,24 +172,8 @@ def chat(
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
device = device if "=>" not in device else device.split("=>")[1].strip()
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
elif "metal" in device:
|
||||
device = "metal"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
@@ -325,19 +330,7 @@ def llm_chat_api(InputData: dict):
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "metal" in device:
|
||||
device = "metal"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
|
||||
@@ -34,9 +34,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
if device_uri[0] not in ["vulkan", "rocm"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Specific device selection only supported for vulkan and rocm."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
# device_uri can be device_num or device_path.
|
||||
@@ -83,7 +83,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args(extra_args=extra_args)
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
return []
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user