Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, last synced 2026-04-03 03:00:17 -04:00.
[vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -464,6 +464,7 @@ class Vicuna(SharkLLMBase):
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok
|
||||
|
||||
res.append(detok)
|
||||
res_tokens.append(token)
|
||||
@@ -505,6 +506,8 @@ class Vicuna(SharkLLMBase):
|
||||
res.append(detok)
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
yield detok
|
||||
|
||||
if self.device == "cuda":
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -65,16 +65,14 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
sentence = vicuna_model.generate(prompt)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in sentence.split(" "):
|
||||
for new_text in vicuna_model.generate(prompt):
|
||||
# print(new_text)
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
history[-1][1] = sentence
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
@@ -124,13 +122,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
supported_devices = available_devices
|
||||
supported_devices = available_devices + ["AMD-AIE"]
|
||||
enabled = len(supported_devices) > 0
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
else "No devices supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user