[chat] Update chatbot ui

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-08-11 21:33:01 +05:30
parent 3c577f7168
commit 18801dcabc
2 changed files with 59 additions and 56 deletions

View File

@@ -20,24 +20,24 @@ if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)
# def launch_app(address):
# from tkinter import Tk
# import webview
#
# window = Tk()
#
# # get screen width and height of display and make it more reasonably
# # sized as we aren't making it full-screen or maximized
# width = int(window.winfo_screenwidth() * 0.81)
# height = int(window.winfo_screenheight() * 0.91)
# webview.create_window(
# "SHARK AI Studio",
# url=address,
# width=width,
# height=height,
# text_select=True,
# )
# webview.start(private_mode=False)
if __name__ == "__main__":
@@ -424,11 +424,11 @@ if __name__ == "__main__":
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",

View File

@@ -143,10 +143,11 @@ def chat(
curr_system_message,
history,
model,
devices,
device,
precision,
config_file,
cli=True,
progress=gr.Progress(),
):
global past_key_values
@@ -166,7 +167,6 @@ def chat(
from apps.stable_diffusion.src import args
if vicuna_model == 0:
device = devices[0]
if "cuda" in device:
device = "cuda"
elif "sync" in device:
@@ -189,32 +189,34 @@ def chat(
compressed=True,
)
else:
if len(devices) == 1 and config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
else:
if config_file is not None:
config_file = open(config_file)
config_json = json.load(config_file)
config_file.close()
else:
config_json = get_default_config()
vicuna_model = ShardedVicuna(
model_name,
device=device,
precision=precision,
config_json=config_json,
)
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
# else:
# if config_file is not None:
# config_file = open(config_file)
# config_json = json.load(config_file)
# config_file.close()
# else:
# config_json = get_default_config()
# vicuna_model = ShardedVicuna(
# model_name,
# device=device,
# precision=precision,
# config_json=config_json,
# )
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
for partial_text in progress.tqdm(
vicuna_model.generate(prompt, cli=cli), desc="generating response"
):
history[-1][1] = partial_text
yield history
@@ -365,7 +367,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
)
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
value=model_choices[4],
choices=model_choices,
)
supported_devices = available_devices
@@ -381,24 +383,25 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
multiselect=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
value="int8",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration")
json_view_button = gr.Button("View as JSON")
json_view = gr.JSON(interactive=True)
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)