[ui] Add UI for sharding

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Author: Gaurav Shukla
Date:   2023-12-05 21:20:16 +05:30
parent 1a723645fb
commit c74b55f24e


@@ -141,7 +141,9 @@ def chat(
     prompt_prefix,
     history,
     model,
-    device,
+    backend,
+    devices,
+    sharded,
     precision,
     download_vmfb,
     config_file,
@@ -153,7 +155,8 @@ def chat(
     global vicuna_model

     model_name, model_path = list(map(str.strip, model.split("=>")))
-    device, device_id = clean_device_info(device)
+    device, device_id = clean_device_info(devices[0])
+    no_of_devices = len(devices)

     from apps.language_models.scripts.vicuna import ShardedVicuna
     from apps.language_models.scripts.vicuna import UnshardedVicuna
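Note: with the device dropdown switched to multiselect (see the UI hunks below), devices arrives as a list of device strings; the first entry drives device selection and the count feeds sharding. A minimal sketch, assuming strings of the form backend://id (the exact parsing lives in clean_device_info):

    devices = ["vulkan://0", "vulkan://1"]       # illustrative values from the UI
    device, device_id = devices[0].split("://")  # primary device, e.g. ("vulkan", "0")
    no_of_devices = len(devices)                 # forwarded as n_devices when sharding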
@@ -221,7 +224,7 @@ def chat(
     )
     print(f"extra args = {_extra_args}")
-    if model_name == "vicuna4":
+    if sharded:
         vicuna_model = ShardedVicuna(
             model_name,
             hf_model_path=model_path,
@@ -230,6 +233,7 @@ def chat(
             max_num_tokens=max_toks,
             compressed=True,
             extra_args_cmd=_extra_args,
+            n_devices=no_of_devices,
         )
     else:
         # if config_file is None:
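For illustration, the resulting call when the new "Shard Model" checkbox is on, mirroring only the arguments visible in this diff (the model path and token count are placeholders, not values from the commit):

    vicuna_model = ShardedVicuna(
        "vicuna",                                   # model_name
        hf_model_path="TheBloke/vicuna-7B-1.1-HF",  # placeholder HF path
        max_num_tokens=512,                         # placeholder
        compressed=True,
        extra_args_cmd=[],
        n_devices=no_of_devices,                    # one shard per selected device
    )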
@@ -385,6 +389,16 @@ def view_json_file(file_obj):
         return content


+filtered_devices = dict()
+
+
+def change_backend(backend):
+    new_choices = gr.Dropdown(
+        choices=filtered_devices[backend], label=f"{backend} devices"
+    )
+    return new_choices
+
+
 with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
         model_choices = list(
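The change_backend helper relies on a Gradio 4.x convention: returning a component instance from a callback updates the live component's properties, here the dropdown's choices and label. A standalone sketch of the same pattern, with hypothetical device names:

    import gradio as gr

    options = {"cpu": ["cpu-task0"], "cuda": ["cuda://0", "cuda://1"]}

    with gr.Blocks() as demo:
        backend = gr.Radio(choices=list(options), value="cpu", label="backend")
        device = gr.Dropdown(choices=options["cpu"], label="cpu devices")
        backend.change(
            fn=lambda b: gr.Dropdown(choices=options[b], label=f"{b} devices"),
            inputs=[backend],
            outputs=[device],
        )

    demo.launch()  # launch to try the swap interactively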
@@ -401,15 +415,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         # show cpu-task device first in list for chatbot
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
+        backend_list = ["cpu", "cuda", "vulkan", "rocm"]
+        for x in backend_list:
+            filtered_devices[x] = [y for y in supported_devices if x in y]
+        print(filtered_devices)
+        backend = gr.Radio(
+            label="backend",
+            value="cpu",
+            choices=backend_list,
+        )
         device = gr.Dropdown(
-            label="Device",
-            value=supported_devices[0]
-            if enabled
-            else "Only CUDA Supported for now",
-            choices=supported_devices,
-            interactive=enabled,
+            label="cpu devices",
+            choices=filtered_devices["cpu"],
+            interactive=True,
             allow_custom_value=True,
-            # multiselect=True,
+            multiselect=True,
         )
         precision = gr.Radio(
             label="Precision",
@@ -433,6 +454,11 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             value=False,
             interactive=True,
         )
+        sharded = gr.Checkbox(
+            label="Shard Model",
+            value=False,
+            interactive=True,
+        )

     with gr.Row(visible=False):
         with gr.Group():
@@ -460,6 +486,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             stop = gr.Button("Stop", interactive=enabled)
             clear = gr.Button("Clear", interactive=enabled)

+    backend.change(
+        fn=change_backend,
+        inputs=[backend],
+        outputs=[device],
+        show_progress=False,
+    )
+
     submit_event = msg.submit(
         fn=user,
         inputs=[msg, chatbot],
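Passing show_progress=False here swaps the dropdown's choices in place without flashing Gradio's loading indicator each time the radio selection changes, which keeps the backend toggle feeling instant.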
@@ -472,7 +505,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
             chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,
@@ -493,7 +528,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
             chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,