Fix CLI for Vicuna (#1666)

This commit is contained in:
Daniel Garvey
2023-07-18 12:03:40 -05:00
committed by GitHub
parent b0136593df
commit 8c317e4809
2 changed files with 17 additions and 21 deletions

View File

@@ -74,14 +74,17 @@ def create_prompt(model_name, history):
return msg
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision):
global sharded_model
global past_key_values
def set_vicuna_model(model):
    """Store *model* in the module-level ``vicuna_model`` global.

    ``chat`` declares ``global vicuna_model`` and calls
    ``vicuna_model.generate(...)``, so this setter lets callers inject a
    pre-built model for ``chat`` to use.
    """
    global vicuna_model
    vicuna_model = model
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
global past_key_values
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
print(f"In chat for {model_name}")
if model_name in ["vicuna", "vicuna1p3", "codegen"]:
from apps.language_models.scripts.vicuna import (
@@ -109,9 +112,8 @@ def chat(curr_system_message, history, model, device, precision):
max_num_tokens=max_toks,
)
prompt = create_prompt(model_name, history)
print("prompt = ", prompt)
for partial_text in vicuna_model.generate(prompt):
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
@@ -140,7 +142,7 @@ def chat(curr_system_message, history, model, device, precision):
partial_text = ""
for new_text in words_list:
# print(new_text)
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated