mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[chatbot] Add tokens generated per second (#1753)
This commit is contained in:
@@ -8,6 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -214,13 +215,27 @@ def chat(
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
for partial_text in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli), desc="generating response"
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=False),
|
||||
desc="generating response",
|
||||
):
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(
|
||||
format(tokens_per_sec, ".2f")
|
||||
) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
|
||||
return history
|
||||
return history, ""
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
@@ -245,7 +260,6 @@ def chat(
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
@@ -375,7 +389,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
@@ -395,6 +409,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
@@ -429,7 +445,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
@@ -437,7 +453,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
Reference in New Issue
Block a user