mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -1682,6 +1682,68 @@ class UnshardedVicuna(VicunaBase):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
global start_message
|
||||
system_message = start_message[model_name]
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
@@ -1745,10 +1807,7 @@ if __name__ == "__main__":
|
||||
answer to a question, please don't share false information."""
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
|
||||
|
||||
history = []
|
||||
set_vicuna_model(vic)
|
||||
|
||||
model_list = {
|
||||
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
|
||||
@@ -1759,13 +1818,8 @@ if __name__ == "__main__":
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
chat_history, msg = chat(
|
||||
system_message,
|
||||
history,
|
||||
model=model_list[args.model_name],
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
config_file=None,
|
||||
cli=True,
|
||||
)
|
||||
history = list(chat_history)[0]
|
||||
prompt = create_prompt(model_name, history)
|
||||
for text, msg in vic.generate(prompt, cli=True):
|
||||
if "formatted" in msg:
|
||||
print("Response:",text)
|
||||
history[-1][1] = text
|
||||
|
||||
Reference in New Issue
Block a user