mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
[chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -56,7 +56,7 @@ parser = argparse.ArgumentParser(
|
||||
description="runs a vicuna model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
|
||||
"--precision", "-p", default="int8", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -408,7 +408,7 @@ class VicunaBase(SharkLLMBase):
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
else:
|
||||
print(len(output))
|
||||
# print(len(output))
|
||||
_logits = torch.tensor(output[0])
|
||||
_past_key_values = torch.tensor(output[1:])
|
||||
_token = torch.argmax(_logits[:, -1, :], dim=1)
|
||||
@@ -1624,7 +1624,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
return res_str
|
||||
|
||||
def generate(self, prompt, cli=True):
|
||||
def generate(self, prompt, cli):
|
||||
# TODO: refactor for cleaner integration
|
||||
if self.shark_model is None:
|
||||
self.compile()
|
||||
@@ -1632,7 +1632,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=False
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
@@ -1655,7 +1655,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
@@ -1759,14 +1759,13 @@ if __name__ == "__main__":
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
history = list(
|
||||
chat(
|
||||
chat_history, msg = chat(
|
||||
system_message,
|
||||
history,
|
||||
model=model_list[args.model_name],
|
||||
devices=args.device,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
config_file=None,
|
||||
cli=args.cli,
|
||||
cli=True,
|
||||
)
|
||||
)[0]
|
||||
history = list(chat_history)[0]
|
||||
|
||||
@@ -37,7 +37,7 @@ def launch_app(address):
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -150,7 +150,7 @@ def chat(
|
||||
device,
|
||||
precision,
|
||||
config_file,
|
||||
cli=True,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
@@ -235,7 +235,7 @@ def chat(
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=False),
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
|
||||
Reference in New Issue
Block a user