[ui] Add UI for sharding
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
@@ -141,7 +141,9 @@ def chat(
     prompt_prefix,
     history,
     model,
-    device,
+    backend,
+    devices,
+    sharded,
     precision,
     download_vmfb,
     config_file,
@@ -153,7 +155,8 @@ def chat(
     global vicuna_model
 
     model_name, model_path = list(map(str.strip, model.split("=>")))
-    device, device_id = clean_device_info(device)
+    device, device_id = clean_device_info(devices[0])
+    no_of_devices = len(devices)
 
     from apps.language_models.scripts.vicuna import ShardedVicuna
     from apps.language_models.scripts.vicuna import UnshardedVicuna
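
Since the device dropdown becomes a multiselect (see the UI hunks further down), chat() now receives devices as a list of device strings rather than a single value: the first selection becomes the primary device and the selection count later feeds n_devices. A minimal sketch of that handling, with hypothetical device strings and a stand-in parser (the real clean_device_info lives elsewhere in the repo and is not shown in this diff):

from typing import List, Tuple


def clean_device_info_sketch(raw: str) -> Tuple[str, str]:
    # Stand-in only: split strings such as "vulkan://0" into a driver
    # name and a device index; entries like "cpu-task" default to "0".
    if "://" in raw:
        driver, idx = raw.split("://", 1)
        return driver, idx
    return raw, "0"


def handle_devices(devices: List[str]) -> None:
    # Mirrors the hunk above: primary device from the first selection,
    # selection count reused later as n_devices for ShardedVicuna.
    device, device_id = clean_device_info_sketch(devices[0])
    no_of_devices = len(devices)
    print(device, device_id, no_of_devices)


handle_devices(["vulkan://0", "vulkan://1"])  # hypothetical selection
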
@@ -221,7 +224,7 @@ def chat(
     )
     print(f"extra args = {_extra_args}")
 
-    if model_name == "vicuna4":
+    if sharded:
         vicuna_model = ShardedVicuna(
             model_name,
             hf_model_path=model_path,
@@ -230,6 +233,7 @@ def chat(
             max_num_tokens=max_toks,
             compressed=True,
             extra_args_cmd=_extra_args,
+            n_devices=no_of_devices,
         )
     else:
         # if config_file is None:
@@ -385,6 +389,16 @@ def view_json_file(file_obj):
     return content
 
 
+filtered_devices = dict()
+
+
+def change_backend(backend):
+    new_choices = gr.Dropdown(
+        choices=filtered_devices[backend], label=f"{backend} devices"
+    )
+    return new_choices
+
+
 with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
         model_choices = list(
@@ -401,15 +415,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         # show cpu-task device first in list for chatbot
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
+        backend_list = ["cpu", "cuda", "vulkan", "rocm"]
+        for x in backend_list:
+            filtered_devices[x] = [y for y in supported_devices if x in y]
+        print(filtered_devices)
+
+        backend = gr.Radio(
+            label="backend",
+            value="cpu",
+            choices=backend_list,
+        )
         device = gr.Dropdown(
-            label="Device",
-            value=supported_devices[0]
-            if enabled
-            else "Only CUDA Supported for now",
-            choices=supported_devices,
-            interactive=enabled,
+            label="cpu devices",
+            choices=filtered_devices["cpu"],
+            interactive=True,
             allow_custom_value=True,
-            # multiselect=True,
+            multiselect=True,
         )
         precision = gr.Radio(
             label="Precision",
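
The filtered_devices map added above is what the backend radio and the device dropdown key off: each backend name simply collects every supported device string that contains it as a substring. A short self-contained illustration with hypothetical device strings (the real supported_devices list is discovered from the installed drivers):

supported_devices = ["cpu-task", "cuda://0", "vulkan://0", "vulkan://1"]  # hypothetical
backend_list = ["cpu", "cuda", "vulkan", "rocm"]

filtered_devices = {}
for backend in backend_list:
    # Same substring filter as the comprehension in the hunk above.
    filtered_devices[backend] = [d for d in supported_devices if backend in d]

print(filtered_devices)
# {'cpu': ['cpu-task'], 'cuda': ['cuda://0'],
#  'vulkan': ['vulkan://0', 'vulkan://1'], 'rocm': []}
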
@@ -433,6 +454,11 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             value=False,
             interactive=True,
         )
+        sharded = gr.Checkbox(
+            label="Shard Model",
+            value=False,
+            interactive=True,
+        )
 
     with gr.Row(visible=False):
         with gr.Group():
@@ -460,6 +486,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         stop = gr.Button("Stop", interactive=enabled)
         clear = gr.Button("Clear", interactive=enabled)
 
+    backend.change(
+        fn=change_backend,
+        inputs=[backend],
+        outputs=[device],
+        show_progress=False,
+    )
+
     submit_event = msg.submit(
         fn=user,
         inputs=[msg, chatbot],
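
The backend.change wiring above is the usual Gradio pattern for letting one input rebuild another: the radio's value is passed to change_backend, and the component it returns replaces the device dropdown's configuration. A condensed, self-contained sketch of the same pattern, assuming a Gradio version that (like the new change_backend helper) accepts a component instance as the update value; the device lists are placeholders:

import gradio as gr

# Placeholder stand-in for the filtered_devices map built in the UI code.
filtered_devices = {
    "cpu": ["cpu-task"],
    "vulkan": ["vulkan://0", "vulkan://1"],
}


def change_backend(backend):
    # Returning a configured Dropdown updates the existing one in place.
    return gr.Dropdown(
        choices=filtered_devices[backend], label=f"{backend} devices"
    )


with gr.Blocks() as demo:
    backend = gr.Radio(
        label="backend", value="cpu", choices=list(filtered_devices)
    )
    device = gr.Dropdown(
        label="cpu devices", choices=filtered_devices["cpu"], multiselect=True
    )
    backend.change(
        fn=change_backend, inputs=[backend], outputs=[device], show_progress=False
    )

demo.launch()
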
@@ -472,7 +505,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
             chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,
@@ -493,7 +528,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
             chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,