[chatbot] Add tokens generated per second (#1753)

This commit is contained in:
Gaurav Shukla
2023-08-13 23:55:41 +05:30
committed by GitHub
parent 18801dcabc
commit 4dc9c59611
2 changed files with 32 additions and 18 deletions

View File

@@ -362,7 +362,7 @@ class VicunaBase(SharkLLMBase):
with open(output_name, "rb") as f:
return f.read()
def generate_new_token(self, params, sharded=True):
def generate_new_token(self, params, sharded=True, cli=True):
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
@@ -415,7 +415,8 @@ class VicunaBase(SharkLLMBase):
"past_key_values": _past_key_values,
}
print(f" token : {_token} | detok : {_detok}")
if cli:
print(f" token : {_token} | detok : {_detok}")
return ret_dict
@@ -1628,14 +1629,14 @@ class UnshardedVicuna(VicunaBase):
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
generated_token_op = self.generate_new_token(
params=params, sharded=False
params=params, sharded=False, cli=False
)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]
yield detok
yield detok, ""
res_tokens.append(token)
if cli:
@@ -1668,14 +1669,11 @@ class UnshardedVicuna(VicunaBase):
else:
if cli:
print(f"{detok}", end=" ", flush=True)
if len(res_tokens) % 3 == 0:
part_str = self.decode_tokens(res_tokens)
yield part_str
yield detok, ""
res_str = self.decode_tokens(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
yield res_str
yield res_str, "formatted"
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.

View File

@@ -8,6 +8,7 @@ from transformers import (
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import time
def user(message, history):
@@ -214,13 +215,27 @@ def chat(
prompt = create_prompt(model_name, history)
for partial_text in progress.tqdm(
vicuna_model.generate(prompt, cli=cli), desc="generating response"
partial_text = ""
count = 0
start_time = time.time()
for text, msg in progress.tqdm(
vicuna_model.generate(prompt, cli=False),
desc="generating response",
):
history[-1][1] = partial_text
yield history
count += 1
if "formatted" in msg:
history[-1][1] = text
end_time = time.time()
tokens_per_sec = count / (end_time - start_time)
yield history, str(
format(tokens_per_sec, ".2f")
) + " tokens/sec"
else:
partial_text += text + " "
history[-1][1] = partial_text
yield history, ""
return history
return history, ""
# else Model is StableLM
global sharkModel
@@ -245,7 +260,6 @@ def chat(
partial_text = ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
@@ -375,7 +389,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
# print(supported_devices)
devices = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -395,6 +409,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
],
visible=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
@@ -429,7 +445,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
outputs=[chatbot, tokens_time],
queue=True,
)
submit_click_event = submit.click(
@@ -437,7 +453,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
outputs=[chatbot, tokens_time],
queue=True,
)
stop.click(