mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Fix vicuna response
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -621,13 +621,7 @@ def user(message, history):
|
||||
|
||||
def chat(curr_system_message, history):
|
||||
global sharded_model
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
past_key_values = None
|
||||
# user_prompt = input("User: ")
|
||||
# prompt_history = (
|
||||
# prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
# )
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
@@ -635,14 +629,11 @@ def chat(curr_system_message, history):
|
||||
]
|
||||
)
|
||||
print(messages)
|
||||
# prompt = prompt_history.strip()
|
||||
prompt = messages.strip()
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
tokens = input_ids
|
||||
# prompt = print("Robot:", end=" ")
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
new_sentence_str = ""
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
@@ -660,10 +651,6 @@ def chat(curr_system_message, history):
|
||||
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
|
||||
if new_token == 2:
|
||||
break
|
||||
new_text = tokenizer.decode(new_token)
|
||||
new_sentence_str += new_text
|
||||
history[-1][1] = new_sentence_str
|
||||
yield history
|
||||
new_sentence += [new_token]
|
||||
tokens.append(new_token)
|
||||
original_input_ids.append(new_token)
|
||||
@@ -672,10 +659,10 @@ def chat(curr_system_message, history):
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
tokens[i] = int(tokens[i][0])
|
||||
new_sentence_str2 = tokenizer.decode(new_sentence)
|
||||
print(new_sentence_str2)
|
||||
prompt_history += f"\n{new_sentence_str}\n"
|
||||
return new_sentence_str2
|
||||
new_sentence_str = tokenizer.decode(new_sentence)
|
||||
print(new_sentence_str)
|
||||
history[-1][1] = new_sentence_str
|
||||
return history
|
||||
|
||||
|
||||
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
|
||||
Reference in New Issue
Block a user