fix cli for vicuna (#1666)

This commit is contained in:
Daniel Garvey
2023-07-18 12:03:40 -05:00
committed by GitHub
parent b0136593df
commit 8c317e4809
2 changed files with 17 additions and 21 deletions

View File

@@ -1564,21 +1564,15 @@ if __name__ == "__main__":
config_json=config_json,
weight_group_size=args.weight_group_size,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
history = []
set_vicuna_model(vic)
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
prompt_history = (
prompt_history + "USER:\n" + user_prompt + prologue_prompt
)
prompt = prompt_history.strip()
res_str = vic.generate(prompt, cli=True)
torch.cuda.empty_cache()
gc.collect()
print(
"\n-----\nAssistant: Here's the complete formatted reply:\n",
res_str,
)
prompt_history += f"\n{res_str}\n"
history.append([user_prompt,""])
history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]

View File

@@ -74,14 +74,17 @@ def create_prompt(model_name, history):
return msg
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision):
global sharded_model
global past_key_values
def set_vicuna_model(model):
    """Register *model* as the module-wide Vicuna instance consumed by chat().

    Stores the object in this module's namespace so subsequent chat() calls
    reuse the already-loaded model instead of reloading it.
    """
    globals()["vicuna_model"] = model
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
global past_key_values
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
print(f"In chat for {model_name}")
if model_name in ["vicuna", "vicuna1p3", "codegen"]:
from apps.language_models.scripts.vicuna import (
@@ -109,9 +112,8 @@ def chat(curr_system_message, history, model, device, precision):
max_num_tokens=max_toks,
)
prompt = create_prompt(model_name, history)
print("prompt = ", prompt)
for partial_text in vicuna_model.generate(prompt):
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
@@ -140,7 +142,7 @@ def chat(curr_system_message, history, model, device, precision):
partial_text = ""
for new_text in words_list:
# print(new_text)
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated