From 9ae8bc921e5606453c8bcb00445f736324e478cc Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Thu, 17 Aug 2023 19:17:59 +0530 Subject: [PATCH] [chatbot] Fix chatbot cli and webview warning Signed-off-by: Gaurav Shukla --- apps/language_models/scripts/vicuna.py | 19 +++++++++---------- apps/stable_diffusion/web/index.py | 2 +- apps/stable_diffusion/web/ui/stablelm_ui.py | 4 ++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 486f3b81..39312741 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -56,7 +56,7 @@ parser = argparse.ArgumentParser( description="runs a vicuna model", ) parser.add_argument( - "--precision", "-p", default="fp32", help="fp32, fp16, int8, int4" + "--precision", "-p", default="int8", help="fp32, fp16, int8, int4" ) parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") parser.add_argument( @@ -408,7 +408,7 @@ class VicunaBase(SharkLLMBase): _past_key_values = output["past_key_values"] _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) else: - print(len(output)) + # print(len(output)) _logits = torch.tensor(output[0]) _past_key_values = torch.tensor(output[1:]) _token = torch.argmax(_logits[:, -1, :], dim=1) @@ -1624,7 +1624,7 @@ class UnshardedVicuna(VicunaBase): ) return res_str - def generate(self, prompt, cli=True): + def generate(self, prompt, cli): # TODO: refactor for cleaner integration if self.shark_model is None: self.compile() @@ -1632,7 +1632,7 @@ params = {"prompt": prompt, "is_first": True, "fv": self.shark_model} generated_token_op = self.generate_new_token( - params=params, sharded=False, cli=False + params=params, sharded=False, cli=cli ) token = generated_token_op["token"] @@ -1655,7 +1655,7 @@ } generated_token_op = self.generate_new_token( - params=params, sharded=False + 
params=params, sharded=False, cli=cli ) token = generated_token_op["token"] @@ -1759,14 +1759,13 @@ if __name__ == "__main__": # TODO: Add break condition from user input user_prompt = input("User: ") history.append([user_prompt, ""]) - history = list( - chat( + chat_history, msg = chat( system_message, history, model=model_list[args.model_name], - devices=args.device, + device=args.device, precision=args.precision, config_file=None, - cli=args.cli, + cli=True, ) - )[0] + history = list(chat_history)[0] diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index e3804f4b..dcb1e570 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -37,7 +37,7 @@ def launch_app(address): height=height, text_select=True, ) - webview.start(private_mode=False) + webview.start(private_mode=False, storage_path=os.getcwd()) if __name__ == "__main__": diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 4264e2c7..2bee97a6 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -150,7 +150,7 @@ def chat( device, precision, config_file, - cli=True, + cli=False, progress=gr.Progress(), ): global past_key_values @@ -235,7 +235,7 @@ def chat( count = 0 start_time = time.time() for text, msg in progress.tqdm( - vicuna_model.generate(prompt, cli=False), + vicuna_model.generate(prompt, cli=cli), desc="generating response", ): count += 1