[chatbot] Fix chatbot CLI flag handling and webview warning

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Gaurav Shukla
2023-08-17 19:17:59 +05:30
parent 32eb78f0f9
commit 9ae8bc921e
3 changed files with 12 additions and 13 deletions

View File

@@ -56,7 +56,7 @@ parser = argparse.ArgumentParser(
     description="runs a vicuna model",
 )
 parser.add_argument(
-    "--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
+    "--precision", "-p", default="int8", help="fp32, fp16, int8, int4"
 )
 parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
 parser.add_argument(
@@ -408,7 +408,7 @@ class VicunaBase(SharkLLMBase):
             _past_key_values = output["past_key_values"]
             _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
         else:
-            print(len(output))
+            # print(len(output))
             _logits = torch.tensor(output[0])
             _past_key_values = torch.tensor(output[1:])
             _token = torch.argmax(_logits[:, -1, :], dim=1)
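
A note on the decode step around the commented-out print: `_logits[:, -1, :]` takes the logits of the last sequence position and `torch.argmax(..., dim=1)` performs greedy next-token selection. A tiny self-contained illustration with made-up shapes, not repository code:

import torch

# Made-up shapes: batch of 1, sequence length 5, vocabulary of 10.
logits = torch.randn(1, 5, 10)

# Greedy decoding: pick the highest-scoring token at the last position.
next_token = torch.argmax(logits[:, -1, :], dim=1)
print(next_token.shape)  # torch.Size([1])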
@@ -1624,7 +1624,7 @@ class UnshardedVicuna(VicunaBase):
         )
         return res_str
 
-    def generate(self, prompt, cli=True):
+    def generate(self, prompt, cli):
         # TODO: refactor for cleaner integration
         if self.shark_model is None:
             self.compile()
@@ -1632,7 +1632,7 @@ class UnshardedVicuna(VicunaBase):
         params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
         generated_token_op = self.generate_new_token(
-            params=params, sharded=False, cli=False
+            params=params, sharded=False, cli=cli
         )
         token = generated_token_op["token"]
@@ -1655,7 +1655,7 @@ class UnshardedVicuna(VicunaBase):
             }
             generated_token_op = self.generate_new_token(
-                params=params, sharded=False
+                params=params, sharded=False, cli=cli
             )
             token = generated_token_op["token"]
@@ -1759,14 +1759,13 @@ if __name__ == "__main__":
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
         history.append([user_prompt, ""])
-        history = list(
-            chat(
+        chat_history, msg = chat(
                 system_message,
                 history,
                 model=model_list[args.model_name],
-                devices=args.device,
+                device=args.device,
                 precision=args.precision,
                 config_file=None,
-                cli=args.cli,
+                cli=True,
             )
-        )[0]
+        history = list(chat_history)[0]
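
The hunks above thread a single `cli` flag from the entry point down to token generation instead of hard-coding it at each call site. A minimal sketch of the intended flow, with simplified stand-in bodies rather than the repository code:

def generate_new_token(params, sharded, cli):
    # Stand-in decode step; echoes output only when invoked from the command line.
    token = "<tok>"
    if cli:
        print(token, end="", flush=True)
    return {"token": token}


def generate(prompt, cli):
    # Forwards the flag instead of hard-coding cli=False.
    params = {"prompt": prompt, "is_first": True}
    op = generate_new_token(params=params, sharded=False, cli=cli)
    return [op["token"]]


def chat(system_message, history, cli=False):
    # The web UI keeps the default cli=False; the CLI entry point passes cli=True.
    prompt = system_message + " " + history[-1][0]
    return generate(prompt, cli=cli), "done"


if __name__ == "__main__":
    chat_history, msg = chat("You are a helpful assistant.", [["hello", ""]], cli=True)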

View File

@@ -37,7 +37,7 @@ def launch_app(address):
         height=height,
         text_select=True,
     )
-    webview.start(private_mode=False)
+    webview.start(private_mode=False, storage_path=os.getcwd())
 
 
 if __name__ == "__main__":
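
For context on the webview change: recent pywebview versions warn at startup when `private_mode=False` is used without an explicit `storage_path`, so passing a writable directory silences the warning and gives persisted data (cookies, local storage) a known location. A minimal standalone sketch; the window title and URL below are placeholders, not values from this repository:

import os

import webview  # pywebview

# Hypothetical window; title and URL are placeholders.
webview.create_window(
    "Chatbot",
    url="http://localhost:8080",
    text_select=True,
)

# private_mode=False persists cookies/local storage; an explicit storage_path
# avoids the pywebview startup warning about where that data should live.
webview.start(private_mode=False, storage_path=os.getcwd())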

View File

@@ -150,7 +150,7 @@ def chat(
     device,
     precision,
     config_file,
-    cli=True,
+    cli=False,
     progress=gr.Progress(),
 ):
     global past_key_values
@@ -235,7 +235,7 @@
     count = 0
     start_time = time.time()
     for text, msg in progress.tqdm(
-        vicuna_model.generate(prompt, cli=False),
+        vicuna_model.generate(prompt, cli=cli),
         desc="generating response",
     ):
         count += 1
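
The final hunk streams tokens through Gradio's progress tracker; `gr.Progress().tqdm` wraps any iterable and updates the UI as it is consumed. A small sketch of that pattern with a stand-in generator in place of `vicuna_model.generate`:

import time

import gradio as gr


def fake_generate(prompt):
    # Stand-in for vicuna_model.generate: yields (partial_text, status) pairs.
    for word in prompt.split():
        time.sleep(0.1)
        yield word, "ok"


def respond(prompt, progress=gr.Progress()):
    parts = []
    for text, msg in progress.tqdm(fake_generate(prompt), desc="generating response"):
        parts.append(text)
    return " ".join(parts)


demo = gr.Interface(fn=respond, inputs="text", outputs="text")
# demo.launch()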