Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-25 03:00:12 -04:00)
[chatbot] Add tokens generated per second (#1753)
This commit is contained in:
@@ -362,7 +362,7 @@ class VicunaBase(SharkLLMBase):
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def generate_new_token(self, params, sharded=True):
|
||||
def generate_new_token(self, params, sharded=True, cli=True):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
@@ -415,7 +415,8 @@ class VicunaBase(SharkLLMBase):
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
if cli:
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
|
||||
return ret_dict
|
||||
|
||||
@@ -1628,14 +1629,14 @@ class UnshardedVicuna(VicunaBase):
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False
|
||||
params=params, sharded=False, cli=False
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok
|
||||
yield detok, ""
|
||||
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
@@ -1668,14 +1669,11 @@ class UnshardedVicuna(VicunaBase):
|
||||
else:
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
|
||||
if len(res_tokens) % 3 == 0:
|
||||
part_str = self.decode_tokens(res_tokens)
|
||||
yield part_str
|
||||
yield detok, ""
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
yield res_str
|
||||
yield res_str, "formatted"
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
|
||||
@@ -8,6 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -214,13 +215,27 @@ def chat(
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
for partial_text in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli), desc="generating response"
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=False),
|
||||
desc="generating response",
|
||||
):
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(
|
||||
format(tokens_per_sec, ".2f")
|
||||
) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
|
||||
return history
|
||||
return history, ""
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
@@ -245,7 +260,6 @@ def chat(
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
@@ -375,7 +389,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
@@ -395,6 +409,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
@@ -429,7 +445,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
@@ -437,7 +453,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
Reference in New Issue
Block a user