[chatbot] Add tokens generated per second (#1753)

This commit is contained in:
Gaurav Shukla
2023-08-13 23:55:41 +05:30
committed by GitHub
parent 18801dcabc
commit 4dc9c59611
2 changed files with 32 additions and 18 deletions

View File

@@ -8,6 +8,7 @@ from transformers import (
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import time
def user(message, history):
@@ -214,13 +215,27 @@ def chat(
prompt = create_prompt(model_name, history)
for partial_text in progress.tqdm(
vicuna_model.generate(prompt, cli=cli), desc="generating response"
partial_text = ""
count = 0
start_time = time.time()
for text, msg in progress.tqdm(
vicuna_model.generate(prompt, cli=False),
desc="generating response",
):
history[-1][1] = partial_text
yield history
count += 1
if "formatted" in msg:
history[-1][1] = text
end_time = time.time()
tokens_per_sec = count / (end_time - start_time)
yield history, str(
format(tokens_per_sec, ".2f")
) + " tokens/sec"
else:
partial_text += text + " "
history[-1][1] = partial_text
yield history, ""
return history
return history, ""
# else Model is StableLM
global sharkModel
@@ -245,7 +260,6 @@ def chat(
partial_text = ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
@@ -375,7 +389,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
# print(supported_devices)
devices = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -395,6 +409,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
],
visible=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
@@ -429,7 +445,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
outputs=[chatbot, tokens_time],
queue=True,
)
submit_click_event = submit.click(
@@ -437,7 +453,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
outputs=[chatbot, tokens_time],
queue=True,
)
stop.click(