[SD] Yield 2 tokens at a time in vicuna

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-05-11 22:58:58 +05:30
parent 649f39408b
commit e0cc2871bb

View File

@@ -1,6 +1,7 @@
import sys
import warnings
import gradio as gr
import time
warnings.filterwarnings("ignore")
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
@@ -634,6 +635,9 @@ def chat(curr_system_message, history):
tokens = input_ids
new_sentence = []
max_response_len = 1000
partial_sentence = []
partial_text = ""
start_time = time.time()
for iteration in range(max_response_len):
original_input_ids = input_ids
input_id_len = len(input_ids)
@@ -652,9 +656,22 @@ def chat(curr_system_message, history):
if new_token == 2:
break
new_sentence += [new_token]
partial_sentence += [new_token]
if iteration > 0 and iteration % 2 == 0:
new_text = tokenizer.decode(partial_sentence)
partial_sentence = []
print(new_text, " ")
partial_text += new_text + " "
history[-1][1] = partial_text
yield history
tokens.append(new_token)
original_input_ids.append(new_token)
input_ids = [new_token]
end_time = time.time()
print(
f"Total time taken to generated response is {end_time-start_time} seconds"
)
for i in range(len(tokens)):
if type(tokens[i]) != int: