diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 6f81c846..5b13145a 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -141,7 +141,9 @@ def chat(
     prompt_prefix,
     history,
     model,
-    device,
+    backend,
+    devices,
+    sharded,
     precision,
     download_vmfb,
     config_file,
@@ -153,7 +155,8 @@ def chat(
     global vicuna_model

     model_name, model_path = list(map(str.strip, model.split("=>")))
-    device, device_id = clean_device_info(device)
+    device, device_id = clean_device_info(devices[0])
+    no_of_devices = len(devices)

     from apps.language_models.scripts.vicuna import ShardedVicuna
     from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -221,7 +224,7 @@ def chat(
         )
         print(f"extra args = {_extra_args}")

-        if model_name == "vicuna4":
+        if sharded:
             vicuna_model = ShardedVicuna(
                 model_name,
                 hf_model_path=model_path,
@@ -230,6 +233,7 @@ def chat(
                 max_num_tokens=max_toks,
                 compressed=True,
                 extra_args_cmd=_extra_args,
+                n_devices=no_of_devices,
             )
         else:
             # if config_file is None:
@@ -385,6 +389,16 @@ def view_json_file(file_obj):
     return content


+filtered_devices = dict()
+
+
+def change_backend(backend):
+    new_choices = gr.Dropdown(
+        choices=filtered_devices[backend], label=f"{backend} devices"
+    )
+    return new_choices
+
+
 with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
         model_choices = list(
@@ -401,15 +415,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         # show cpu-task device first in list for chatbot
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
+        backend_list = ["cpu", "cuda", "vulkan", "rocm"]
+        for x in backend_list:
+            filtered_devices[x] = [y for y in supported_devices if x in y]
+        print(filtered_devices)
+
+        backend = gr.Radio(
+            label="backend",
+            value="cpu",
+            choices=backend_list,
+        )
         device = gr.Dropdown(
-            label="Device",
-            value=supported_devices[0]
-            if enabled
-            else "Only CUDA Supported for now",
-            choices=supported_devices,
-            interactive=enabled,
+            label="cpu devices",
+            choices=filtered_devices["cpu"],
+            interactive=True,
             allow_custom_value=True,
-            # multiselect=True,
+            multiselect=True,
         )
         precision = gr.Radio(
             label="Precision",
@@ -433,6 +454,11 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             value=False,
             interactive=True,
         )
+        sharded = gr.Checkbox(
+            label="Shard Model",
+            value=False,
+            interactive=True,
+        )

     with gr.Row(visible=False):
         with gr.Group():
@@ -460,6 +486,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             stop = gr.Button("Stop", interactive=enabled)
             clear = gr.Button("Clear", interactive=enabled)

+    backend.change(
+        fn=change_backend,
+        inputs=[backend],
+        outputs=[device],
+        show_progress=False,
+    )
+
     submit_event = msg.submit(
         fn=user,
         inputs=[msg, chatbot],
@@ -472,7 +505,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
            chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,
@@ -493,7 +528,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             prompt_prefix,
             chatbot,
             model,
+            backend,
             device,
+            sharded,
             precision,
             download_vmfb,
             config_file,
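
For reference, a minimal standalone sketch of the backend-to-device bucketing this diff introduces. The device strings below are hypothetical stand-ins for whatever SHARK's device enumeration actually returns; only the substring matching and the dict layout mirror the diff.

# Sketch of the diff's backend -> device filtering, with made-up device
# identifier strings standing in for SHARK's real enumeration.
supported_devices = [
    "cpu-task",
    "cuda://0 => NVIDIA GeForce RTX 3090",
    "vulkan://0 => AMD Radeon RX 7900",
    "rocm://0 => AMD Instinct MI210",
]

backend_list = ["cpu", "cuda", "vulkan", "rocm"]

# A device belongs to a backend when the backend name appears anywhere in
# its identifier string -- the same test the diff uses.
filtered_devices = {
    backend: [d for d in supported_devices if backend in d]
    for backend in backend_list
}

print(filtered_devices["cuda"])
# ['cuda://0 => NVIDIA GeForce RTX 3090']

In the UI, backend.change(fn=change_backend, ...) feeds the selected radio value through this mapping and returns a fresh gr.Dropdown whose choices replace the device dropdown's (returning a component instance from an event handler updates the bound output in Gradio 4.x). Because the dropdown is now multiselect, chat() receives a list: devices[0] supplies the device info and len(devices) becomes the shard count.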