mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-02-19 11:56:43 -05:00)
[chat] Update chatbot ui
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
@@ -20,24 +20,24 @@ if args.clear_all:
     clear_all()


-def launch_app(address):
-    from tkinter import Tk
-    import webview
-
-    window = Tk()
-
-    # get screen width and height of display and make it more reasonably
-    # sized as we aren't making it full-screen or maximized
-    width = int(window.winfo_screenwidth() * 0.81)
-    height = int(window.winfo_screenheight() * 0.91)
-    webview.create_window(
-        "SHARK AI Studio",
-        url=address,
-        width=width,
-        height=height,
-        text_select=True,
-    )
-    webview.start(private_mode=False)
+# def launch_app(address):
+#     from tkinter import Tk
+#     import webview
+#
+#     window = Tk()
+#
+#     # get screen width and height of display and make it more reasonably
+#     # sized as we aren't making it full-screen or maximized
+#     width = int(window.winfo_screenwidth() * 0.81)
+#     height = int(window.winfo_screenheight() * 0.91)
+#     webview.create_window(
+#         "SHARK AI Studio",
+#         url=address,
+#         width=width,
+#         height=height,
+#         text_select=True,
+#     )
+#     webview.start(private_mode=False)


 if __name__ == "__main__":
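
Note: for readers unfamiliar with the pattern being commented out above, launch_app used tkinter only to read the display size and pywebview to host the Gradio URL in a native window. A minimal standalone sketch of the same pattern follows; the helper name, default URL, and scale factors are illustrative, not the project's code as-is.

# Sketch of the disabled launch_app pattern (hypothetical helper name).
from tkinter import Tk

import webview


def open_native_window(address="http://localhost:8080"):
    root = Tk()
    # Tk is used only to query the display size, then discarded.
    width = int(root.winfo_screenwidth() * 0.81)
    height = int(root.winfo_screenheight() * 0.91)
    root.destroy()

    webview.create_window(
        "SHARK AI Studio",
        url=address,
        width=width,
        height=height,
        text_select=True,  # allow selecting text inside the window
    )
    # private_mode=False keeps cookies/local storage across runs.
    webview.start(private_mode=False)
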
@@ -424,11 +424,11 @@ if __name__ == "__main__":
     )

     sd_web.queue()
-    if args.ui == "app":
-        t = Process(
-            target=launch_app, args=[f"http://localhost:{args.server_port}"]
-        )
-        t.start()
+    # if args.ui == "app":
+    #     t = Process(
+    #         target=launch_app, args=[f"http://localhost:{args.server_port}"]
+    #     )
+    #     t.start()
     sd_web.launch(
         share=args.share,
         inbrowser=args.ui == "web",
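
Note: the "app" branch disabled above spawned the native window in a separate process so the window's event loop and the Gradio server would not block each other. A hedged sketch of that idea, reusing the hypothetical open_native_window helper from the previous note; the port value is a stand-in for args.server_port.

from multiprocessing import Process

if __name__ == "__main__":
    server_port = 8080  # stand-in for args.server_port
    t = Process(
        target=open_native_window,
        args=[f"http://localhost:{server_port}"],
    )
    t.start()
    # The Gradio app (sd_web.launch(...)) would then serve the UI that
    # the window loads; joining t on shutdown is omitted here.
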
@@ -143,10 +143,11 @@ def chat(
     curr_system_message,
     history,
     model,
-    devices,
+    device,
     precision,
     config_file,
     cli=True,
+    progress=gr.Progress(),
 ):
     global past_key_values

@@ -166,7 +167,6 @@ def chat(
     from apps.stable_diffusion.src import args

     if vicuna_model == 0:
-        device = devices[0]
         if "cuda" in device:
             device = "cuda"
         elif "sync" in device:
@@ -189,32 +189,34 @@ def chat(
                 compressed=True,
             )
         else:
-            if len(devices) == 1 and config_file is None:
-                vicuna_model = UnshardedVicuna(
-                    model_name,
-                    hf_model_path=model_path,
-                    hf_auth_token=args.hf_auth_token,
-                    device=device,
-                    precision=precision,
-                    max_num_tokens=max_toks,
-                )
-            else:
-                if config_file is not None:
-                    config_file = open(config_file)
-                    config_json = json.load(config_file)
-                    config_file.close()
-                else:
-                    config_json = get_default_config()
-                vicuna_model = ShardedVicuna(
-                    model_name,
-                    device=device,
-                    precision=precision,
-                    config_json=config_json,
-                )
+            # if config_file is None:
+            vicuna_model = UnshardedVicuna(
+                model_name,
+                hf_model_path=model_path,
+                hf_auth_token=args.hf_auth_token,
+                device=device,
+                precision=precision,
+                max_num_tokens=max_toks,
+            )
+            # else:
+            #     if config_file is not None:
+            #         config_file = open(config_file)
+            #         config_json = json.load(config_file)
+            #         config_file.close()
+            #     else:
+            #         config_json = get_default_config()
+            #     vicuna_model = ShardedVicuna(
+            #         model_name,
+            #         device=device,
+            #         precision=precision,
+            #         config_json=config_json,
+            #     )

     prompt = create_prompt(model_name, history)

-    for partial_text in vicuna_model.generate(prompt, cli=cli):
+    for partial_text in progress.tqdm(
+        vicuna_model.generate(prompt, cli=cli), desc="generating response"
+    ):
         history[-1][1] = partial_text
         yield history
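
Note: the generate loop above is the standard Gradio streaming pattern: chat() is a generator that yields the growing history, and wrapping the token stream in gr.Progress().tqdm adds a progress label in the UI while it runs. A self-contained sketch under those assumptions; fake_generate and the Blocks wiring are illustrative stand-ins for vicuna_model.generate and the real UI.

import time

import gradio as gr


def fake_generate(prompt):
    # Stand-in for vicuna_model.generate(prompt, cli=cli): yields the
    # partial response text as it grows.
    text = ""
    for word in prompt.split():
        time.sleep(0.1)
        text += word + " "
        yield text


def chat(message, history, progress=gr.Progress()):
    history = history + [[message, ""]]
    for partial_text in progress.tqdm(
        fake_generate(message), desc="generating response"
    ):
        history[-1][1] = partial_text
        yield history  # each yield repaints the Chatbot component


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(chat, [msg, chatbot], chatbot)

# demo.queue().launch()  # queue() is required for generator outputs.
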
@@ -365,7 +367,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
             )
             model = gr.Dropdown(
                 label="Select Model",
-                value=model_choices[0],
+                value=model_choices[4],
                 choices=model_choices,
             )
             supported_devices = available_devices
@@ -381,24 +383,25 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
                 else "Only CUDA Supported for now",
                 choices=supported_devices,
                 interactive=enabled,
-                multiselect=True,
+                # multiselect=True,
             )
             precision = gr.Radio(
                 label="Precision",
-                value="fp16",
+                value="int8",
                 choices=[
                     "int4",
                     "int8",
                     "fp16",
                     "fp32",
                 ],
                 visible=True,
             )
-        with gr.Row():
+        with gr.Row(visible=False):
             with gr.Group():
-                config_file = gr.File(label="Upload sharding configuration")
-                json_view_button = gr.Button("View as JSON")
-                json_view = gr.JSON(interactive=True)
+                config_file = gr.File(
+                    label="Upload sharding configuration", visible=False
+                )
+                json_view_button = gr.Button(label="View as JSON", visible=False)
+                json_view = gr.JSON(interactive=True, visible=False)
         json_view_button.click(
             fn=view_json_file, inputs=[config_file], outputs=[json_view]
         )
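
Note: in the last hunk the sharding-config widgets are hidden rather than deleted: visible=False on the Row and on each child removes them from the layout while keeping the components alive, so the json_view_button.click wiring still works if the UI is re-enabled later. A minimal sketch of that pattern; view_json_file here is an illustrative stand-in for the real helper.

import json

import gradio as gr


def view_json_file(file_obj):
    # Illustrative: read the uploaded file and return parsed JSON.
    with open(file_obj.name) as f:
        return json.load(f)


with gr.Blocks() as demo:
    with gr.Row(visible=False):
        with gr.Group():
            config_file = gr.File(
                label="Upload sharding configuration", visible=False
            )
            json_view_button = gr.Button("View as JSON", visible=False)
            json_view = gr.JSON(visible=False)
    # Event handlers can be attached to hidden components.
    json_view_button.click(
        fn=view_json_file, inputs=[config_file], outputs=[json_view]
    )
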