Enable non-default ROCm device selection for webui

This commit is contained in:
PhaneeshB
2023-11-10 20:46:08 +05:30
committed by Phaneesh Barwaria
parent 91df5f0613
commit 392bade0bf
3 changed files with 39 additions and 39 deletions

View File

@@ -1259,6 +1259,7 @@ class UnshardedVicuna(VicunaBase):
max_num_tokens=512,
min_num_tokens=0,
device="cpu",
device_id=None,
vulkan_target_triple="",
precision="int8",
vicuna_mlir_path=None,
@@ -1269,7 +1270,6 @@ class UnshardedVicuna(VicunaBase):
download_vmfb=False,
cache_vicunas=False,
extra_args_cmd=[],
device_id=None,
debug=False,
) -> None:
super().__init__(
@@ -1288,9 +1288,7 @@ class UnshardedVicuna(VicunaBase):
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.min_num_tokens = min_num_tokens
self.device = device
self.vulkan_target_triple = vulkan_target_triple
self.device_id = device_id
self.precision = precision
self.download_vmfb = download_vmfb
self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1299,12 +1297,22 @@ class UnshardedVicuna(VicunaBase):
self.low_device_memory = low_device_memory
self.weight_group_size = weight_group_size
self.debug = debug
# Sanity check for device, device_id pair
if "://" in device:
if device_id is not None:
print("[ERR] can't have both full device path and a device id.\n"
f"Device : {device} | device_id : {device_id}\n"
"proceeding with given Device ignoring device_id")
self.device, self.device_id = device.split("://")
else:
self.device, self.device_id = device, device_id
if self.vicuna_mlir_path == None:
self.vicuna_mlir_path = self.get_model_path()
if self.vicuna_vmfb_path == None:
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
self.tokenizer = self.get_tokenizer()
self.cache_vicunas = cache_vicunas
self.compile()
def get_model_path(self, suffix="mlir"):
@@ -1752,9 +1760,8 @@ class UnshardedVicuna(VicunaBase):
)
del first_module, second_module
print(self.device)
if "rocm" in self.device:
self.device = "rocm"
print(f"Compiling for device : {self.device}"
f"{'://' + str(self.device_id) if self.device_id is not None else ''}")
shark_module = SharkInference(
mlir_module=combined_module,
device=self.device,

View File

@@ -132,6 +132,27 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
    """Split a webui device string into ``(device, device_id)``.

    The result is meant for consumption by the LLM pipeline.

    ``raw_device`` is either a bare device string (``"cpu-task"``,
    ``"vulkan://1"``) or an entry of the form ``"<label> => <device>"``; in
    the latter case only the part after ``=>`` is used, stripped of
    surrounding whitespace.

    Multiple devices are only supported for ``rocm`` and ``vulkan`` (as of
    now). For every other backend the default device must be selected, so
    any ``://<id>`` suffix is dropped and ``device_id`` is ``None``.

    Args:
        raw_device: Device string as described above.

    Returns:
        A ``(device, device_id)`` tuple where ``device`` is the backend name
        without the ``://<id>`` suffix and ``device_id`` is an ``int`` device
        index for ``rocm``/``vulkan`` or ``None`` otherwise.

    Raises:
        ValueError: If a ``rocm``/``vulkan`` device carries an id that is not
            an integer (the webui selects devices by index).
    """
    device = raw_device
    if "=>" in raw_device:
        device = raw_device.split("=>")[1].strip()

    # partition() never fails to unpack, unlike a two-way split("://").
    device, separator, raw_id = device.partition("://")

    # Decide whether the id is relevant *before* parsing it, so that an
    # id which is discarded anyway (e.g. "cuda://GPU-<uuid>") cannot raise.
    if not separator or device not in ("rocm", "vulkan"):
        return device, None

    return device, int(raw_id)  # using device index in webui
model_vmfb_key = ""
@@ -151,24 +172,8 @@ def chat(
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
device = device if "=>" not in device else device.split("=>")[1].strip()
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -325,19 +330,7 @@ def llm_chat_api(InputData: dict):
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
vicuna_model = UnshardedVicuna(
model_name,

View File

@@ -34,9 +34,9 @@ def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan now."
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
@@ -83,7 +83,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(extra_args=extra_args)
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []