mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
fix cli for vicuna (#1666)
@@ -1564,21 +1564,15 @@ if __name__ == "__main__":
         config_json=config_json,
         weight_group_size=args.weight_group_size,
     )
-    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
-    prologue_prompt = "ASSISTANT:\n"
+    system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+
+    from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
+
+    history = []
+    set_vicuna_model(vic)
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
-        prompt_history = (
-            prompt_history + "USER:\n" + user_prompt + prologue_prompt
-        )
-        prompt = prompt_history.strip()
-        res_str = vic.generate(prompt, cli=True)
-        torch.cuda.empty_cache()
-        gc.collect()
-        print(
-            "\n-----\nAssistant: Here's the complete formatted reply:\n",
-            res_str,
-        )
-        prompt_history += f"\n{res_str}\n"
+        history.append([user_prompt, ""])
+        history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
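Why the patched CLI loop can take list(chat(...))[0]: chat() is a generator that repeatedly mutates and re-yields the same history list, so draining it and keeping the first yielded element already gives the finished conversation. The sketch below illustrates only that consumption pattern; demo_chat and its canned tokens are hypothetical stand-ins, not code from this repository.

def demo_chat(system_message, history, cli=True):
    # Stand-in for stablelm_ui.chat(): stream a reply for the last user turn,
    # mutating and re-yielding the same history list after every chunk.
    reply = ""
    for token in ("Hello", ", ", "world", "!"):
        reply += token
        history[-1][1] = reply
        yield history


if __name__ == "__main__":
    system_message = "A chat between a curious user and an assistant.\n"
    history = []
    for user_prompt in ("hi there", "thanks"):  # stand-in for input("User: ")
        history.append([user_prompt, ""])
        # Same consumption pattern as the patched CLI: every yield is the same
        # (mutated) list object, so list(...)[0] holds the finished reply once
        # the generator is exhausted.
        history = list(demo_chat(system_message, history, cli=True))[0]
        print("Assistant:", history[-1][1])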
@@ -74,14 +74,17 @@ def create_prompt(model_name, history):
     return msg


-# TODO: Make chat reusable for UI and API
-def chat(curr_system_message, history, model, device, precision):
-    global sharded_model
-    global past_key_values
+def set_vicuna_model(model):
+    global vicuna_model
+    vicuna_model = model
+
+
+# TODO: Make chat reusable for UI and API
+def chat(curr_system_message, history, model, device, precision, cli=True):
+    global past_key_values
+
     global vicuna_model
     model_name, model_path = list(map(str.strip, model.split("=>")))
     print(f"In chat for {model_name}")
+
     if model_name in ["vicuna", "vicuna1p3", "codegen"]:
         from apps.language_models.scripts.vicuna import (
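For context on the hunk above (apparently apps/stable_diffusion/web/ui/stablelm_ui.py, judging by the import added in the CLI hunk): set_vicuna_model lets the CLI hand an already-built model to the module-level global that chat() uses, so the shared path does not rebuild it. Below is a minimal, self-contained sketch of that setter-plus-lazy-build pattern; _EchoModel and chat_sketch are illustrative stand-ins, and the None check is an assumption about how the cached global is typically consumed.

vicuna_model = None  # module-level cache, mirroring the global used by chat()


def set_vicuna_model(model):
    # Same idea as the patch: an external entry point (the vicuna.py CLI)
    # injects an already-built model so the shared chat path reuses it.
    global vicuna_model
    vicuna_model = model


class _EchoModel:
    # Hypothetical stand-in for the real Vicuna wrapper; not SHARK code.
    def generate(self, prompt, cli=True):
        yield f"echo: {prompt}"


def chat_sketch(prompt):
    # Hypothetical consumer: only build a model when nothing was injected.
    global vicuna_model
    if vicuna_model is None:
        vicuna_model = _EchoModel()
    return "".join(vicuna_model.generate(prompt, cli=True))


set_vicuna_model(_EchoModel())
print(chat_sketch("hello"))  # uses the injected instance, prints "echo: hello"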
@@ -109,9 +112,8 @@ def chat(curr_system_message, history, model, device, precision):
                 max_num_tokens=max_toks,
             )
         prompt = create_prompt(model_name, history)
         print("prompt = ", prompt)
-
-        for partial_text in vicuna_model.generate(prompt):
+        for partial_text in vicuna_model.generate(prompt, cli=cli):
             history[-1][1] = partial_text
             yield history
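The vicuna branch simply relays whatever generate() yields into history[-1][1], and the new cli keyword is forwarded so generation can behave differently when driven from a terminal. The sketch below shows one plausible shape of such a producer; generate_sketch is hypothetical, and treating cli as a "print tokens to stdout" switch is an assumption, since the real method lives in vicuna.py and is not part of this diff.

def generate_sketch(prompt, cli=False):
    # Hypothetical producer: yield the accumulated reply after every token.
    # Here cli merely toggles terminal echo, suggesting why the patch
    # threads the flag through from chat() down to generate().
    text = ""
    for token in ("The ", "answer ", "is ", "42."):
        text += token
        if cli:
            print(token, end="", flush=True)
        yield text


history = [["what is the answer?", ""]]
for partial_text in generate_sketch(history[-1][0], cli=False):
    history[-1][1] = partial_text  # the same relay the patched chat() performs
print(history[-1][1])  # -> The answer is 42.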
@@ -140,7 +142,7 @@ def chat(curr_system_message, history, model, device, precision):
 
     partial_text = ""
     for new_text in words_list:
-        # print(new_text)
+        print(new_text)
         partial_text += new_text
         history[-1][1] = partial_text
         # Yield an empty string to clean up the message textbox and the updated