mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import warnings
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
@@ -634,6 +635,9 @@ def chat(curr_system_message, history):
|
||||
tokens = input_ids
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
partial_sentence = []
|
||||
partial_text = ""
|
||||
start_time = time.time()
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
@@ -652,9 +656,22 @@ def chat(curr_system_message, history):
|
||||
if new_token == 2:
|
||||
break
|
||||
new_sentence += [new_token]
|
||||
partial_sentence += [new_token]
|
||||
if iteration > 0 and iteration % 2 == 0:
|
||||
new_text = tokenizer.decode(partial_sentence)
|
||||
partial_sentence = []
|
||||
print(new_text, " ")
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
|
||||
tokens.append(new_token)
|
||||
original_input_ids.append(new_token)
|
||||
input_ids = [new_token]
|
||||
end_time = time.time()
|
||||
print(
|
||||
f"Total time taken to generated response is {end_time-start_time} seconds"
|
||||
)
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
|
||||
Reference in New Issue
Block a user