diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index e3804f4b..851fdc26 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -20,24 +20,24 @@ if args.clear_all:
     clear_all()


-def launch_app(address):
-    from tkinter import Tk
-    import webview
-
-    window = Tk()
-
-    # get screen width and height of display and make it more reasonably
-    # sized as we aren't making it full-screen or maximized
-    width = int(window.winfo_screenwidth() * 0.81)
-    height = int(window.winfo_screenheight() * 0.91)
-    webview.create_window(
-        "SHARK AI Studio",
-        url=address,
-        width=width,
-        height=height,
-        text_select=True,
-    )
-    webview.start(private_mode=False)
+# def launch_app(address):
+#     from tkinter import Tk
+#     import webview
+#
+#     window = Tk()
+#
+#     # get screen width and height of display and make it more reasonably
+#     # sized as we aren't making it full-screen or maximized
+#     width = int(window.winfo_screenwidth() * 0.81)
+#     height = int(window.winfo_screenheight() * 0.91)
+#     webview.create_window(
+#         "SHARK AI Studio",
+#         url=address,
+#         width=width,
+#         height=height,
+#         text_select=True,
+#     )
+#     webview.start(private_mode=False)


 if __name__ == "__main__":
@@ -424,11 +424,11 @@ if __name__ == "__main__":
     )
     sd_web.queue()

-    if args.ui == "app":
-        t = Process(
-            target=launch_app, args=[f"http://localhost:{args.server_port}"]
-        )
-        t.start()
+    # if args.ui == "app":
+    #     t = Process(
+    #         target=launch_app, args=[f"http://localhost:{args.server_port}"]
+    #     )
+    #     t.start()
     sd_web.launch(
         share=args.share,
         inbrowser=args.ui == "web",
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 289f7e8b..9daeb060 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -143,10 +143,11 @@ def chat(
     curr_system_message,
     history,
     model,
-    devices,
+    device,
     precision,
     config_file,
     cli=True,
+    progress=gr.Progress(),
 ):

     global past_key_values
@@ -166,7 +167,6 @@ def chat(
     from apps.stable_diffusion.src import args

     if vicuna_model == 0:
-        device = devices[0]
         if "cuda" in device:
             device = "cuda"
         elif "sync" in device:
@@ -189,32 +189,34 @@ def chat(
             compressed=True,
         )
     else:
-        if len(devices) == 1 and config_file is None:
-            vicuna_model = UnshardedVicuna(
-                model_name,
-                hf_model_path=model_path,
-                hf_auth_token=args.hf_auth_token,
-                device=device,
-                precision=precision,
-                max_num_tokens=max_toks,
-            )
-        else:
-            if config_file is not None:
-                config_file = open(config_file)
-                config_json = json.load(config_file)
-                config_file.close()
-            else:
-                config_json = get_default_config()
-            vicuna_model = ShardedVicuna(
-                model_name,
-                device=device,
-                precision=precision,
-                config_json=config_json,
-            )
+        # if config_file is None:
+        vicuna_model = UnshardedVicuna(
+            model_name,
+            hf_model_path=model_path,
+            hf_auth_token=args.hf_auth_token,
+            device=device,
+            precision=precision,
+            max_num_tokens=max_toks,
+        )
+        # else:
+        #     if config_file is not None:
+        #         config_file = open(config_file)
+        #         config_json = json.load(config_file)
+        #         config_file.close()
+        #     else:
+        #         config_json = get_default_config()
+        #     vicuna_model = ShardedVicuna(
+        #         model_name,
+        #         device=device,
+        #         precision=precision,
+        #         config_json=config_json,
+        #     )

     prompt = create_prompt(model_name, history)

-    for partial_text in vicuna_model.generate(prompt, cli=cli):
+    for partial_text in progress.tqdm(
+        vicuna_model.generate(prompt, cli=cli), desc="generating response"
+    ):
         history[-1][1] = partial_text
         yield history

@@ -365,7 +367,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
                 )
                 model = gr.Dropdown(
                     label="Select Model",
-                    value=model_choices[0],
+                    value=model_choices[4],
                     choices=model_choices,
                 )
                 supported_devices = available_devices
@@ -381,24 +383,25 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
                     else "Only CUDA Supported for now",
                     choices=supported_devices,
                     interactive=enabled,
-                    multiselect=True,
+                    # multiselect=True,
                 )
                 precision = gr.Radio(
                     label="Precision",
-                    value="fp16",
+                    value="int8",
                     choices=[
                         "int4",
                         "int8",
                         "fp16",
-                        "fp32",
                     ],
                     visible=True,
                 )
-            with gr.Row():
+            with gr.Row(visible=False):
                 with gr.Group():
-                    config_file = gr.File(label="Upload sharding configuration")
-                    json_view_button = gr.Button("View as JSON")
-                    json_view = gr.JSON(interactive=True)
+                    config_file = gr.File(
+                        label="Upload sharding configuration", visible=False
+                    )
+                    json_view_button = gr.Button(label="View as JSON", visible=False)
+                    json_view = gr.JSON(interactive=True, visible=False)
                 json_view_button.click(
                     fn=view_json_file, inputs=[config_file], outputs=[json_view]
                 )
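
Note on the gr.Progress change in chat(): below is a minimal, self-contained sketch of the Gradio progress-tracking pattern the diff adopts, assuming Gradio >= 3.22 (where gr.Progress and Progress.tqdm were introduced). The fake_generate helper and the Interface wiring are hypothetical stand-ins for illustration, not code from this repository.

import time

import gradio as gr


def fake_generate(prompt):
    # Hypothetical stand-in for vicuna_model.generate(prompt, cli=cli):
    # yields successively longer partial responses, like the real generator.
    text = ""
    for word in prompt.split():
        time.sleep(0.2)
        text += word + " "
        yield text


def chat(prompt, progress=gr.Progress()):
    # Gradio injects a tracker for any parameter whose default value is
    # gr.Progress(); progress.tqdm() wraps an iterable and reports each
    # step to the UI under the given description while the generator is
    # consumed, which is the pattern used above in stablelm_ui.py.
    for partial_text in progress.tqdm(
        fake_generate(prompt), desc="generating response"
    ):
        yield partial_text


demo = gr.Interface(fn=chat, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.queue().launch()  # queue() is required for generator (streaming) outputs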