[SD] Integrate vicuna in the web (#1410)
@@ -15,6 +15,11 @@ from transformers import (
     StoppingCriteriaList,
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
+from apps.language_models.scripts.sharded_vicuna_fp32 import (
+    tokenizer,
+    SAMPLE_INPUT_LEN,
+    get_sharded_model,
+)
 
 start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
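
The five added lines import the sharded Vicuna entry points into the chat UI. Judging only from how they are used below, the script is expected to expose a ready-made tokenizer, the fixed sequence length the model was compiled for, and a factory for the sharded model. A minimal sketch of that assumed interface (everything here is inferred from the call sites in this diff, not from the script itself; the value is illustrative):

# Assumed shape of apps/language_models/scripts/sharded_vicuna_fp32.py,
# reconstructed from how this diff calls into it.
from transformers import AutoTokenizer

SAMPLE_INPUT_LEN = 256  # padded length the compiled model expects (illustrative)

tokenizer = AutoTokenizer.from_pretrained("TheBloke/vicuna-7B-1.1-HF")


def get_sharded_model():
    """Build or fetch the Vicuna model sharded across SHARK devices.

    The returned object is expected to provide
    forward(input_ids, past_key_values=None, is_first=...) and to return a
    dict with "logits" and "past_key_values", which is how chat() drives it.
    """
    raise NotImplementedError  # stand-in; the real script compiles/loads shards
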
@@ -34,9 +39,63 @@ attention_mask = torch.randint(3, (1, 256))
 
 
 sharkModel = 0
+sharded_model = 0
 
 
-def chat(curr_system_message, history):
+start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+past_key_values = None
+
+
+def chat(curr_system_message, history, model):
+    global sharded_model
+    global past_key_values
+    if "vicuna" in model:
+        curr_system_message = start_message_vicuna
+        if sharded_model == 0:
+            sharded_model = get_sharded_model()
+        messages = curr_system_message + "".join(
+            [
+                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
+                for item in history
+            ]
+        )
+        prompt = messages.strip()
+        input_ids = tokenizer(prompt).input_ids
+        new_sentence = ""
+        for _ in range(200):
+            original_input_ids = input_ids
+            input_id_len = len(input_ids)
+            pad_len = SAMPLE_INPUT_LEN - input_id_len
+            attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
+            input_ids = torch.tensor(input_ids)
+            input_ids = input_ids.reshape([1, input_id_len])
+            attention_mask = torch.nn.functional.pad(
+                torch.tensor(attention_mask),
+                (0, pad_len),
+                mode="constant",
+                value=0,
+            )
+
+            if _ == 0:
+                output = sharded_model.forward(input_ids, is_first=True)
+            else:
+                output = sharded_model.forward(
+                    input_ids, past_key_values=past_key_values, is_first=False
+                )
+            logits = output["logits"]
+            past_key_values = output["past_key_values"]
+            new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
+            if new_word == "</s>":
+                break
+            new_sentence += " " + new_word
+            history[-1][1] = new_sentence
+            yield history
+            next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
+            original_input_ids.append(next_token)
+            input_ids = [next_token]
+        print(new_sentence)
+        return history
+
     global sharkModel
     print("In chat")
     if sharkModel == 0:
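
The vicuna branch above is a plain greedy-decoding loop: the first forward call runs the whole prompt to prime the key/value cache (is_first=True), and every later call feeds back only the newly chosen token. Note that the padded attention_mask is computed but never passed to forward in this version. Stripped of the SHARK-specific padding and sharding details, the core pattern looks roughly like this (a sketch, assuming a causal LM that returns "logits" and "past_key_values"; names are illustrative):

import torch

def greedy_decode(model, tokenizer, prompt, max_new_tokens=200):
    # First call: run the full prompt to prime the key/value cache.
    input_ids = torch.tensor([tokenizer(prompt).input_ids])
    past_key_values = None
    text = ""
    for _ in range(max_new_tokens):
        out = model(input_ids, past_key_values=past_key_values)
        logits = out["logits"]
        past_key_values = out["past_key_values"]
        # Greedy choice: most likely token at the last position.
        next_token = torch.argmax(logits[:, -1, :], dim=-1)
        word = tokenizer.decode(next_token)
        if word == "</s>":  # Vicuna's end-of-sequence marker
            break
        text += " " + word
        # Later calls only need the new token; the cache carries the rest.
        input_ids = next_token.unsqueeze(0)
    return text
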
@@ -95,12 +154,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="stabilityai/stablelm-tuned-alpha-3b",
|
||||
choices=["stabilityai/stablelm-tuned-alpha-3b"],
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "cuda" in d:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
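
Two defaults change in this hunk: the dropdown now defaults to Vicuna while keeping StableLM selectable, and the device scan prefers Vulkan over CUDA, matching SHARK's primary backend. If the scan ever needed a fallback order rather than a single hard-coded substring, it could look like this (pick_device is a hypothetical helper, not part of the commit):

def pick_device(available_devices, preference=("vulkan", "cuda", "cpu")):
    # Return the first available device matching the preference order,
    # falling back to the first device (or None) if nothing matches.
    for kind in preference:
        for d in available_devices:
            if kind in d:
                return d
    return available_devices[0] if available_devices else None
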
@@ -130,12 +192,18 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
|
||||
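
The last two hunks rewire both Gradio event chains so the dropdown value reaches chat() as its third positional argument, and queue=True keeps the generator streaming into the Chatbot. A self-contained sketch of the same wiring (the component layout here is illustrative; only the .then() input list matters):

import gradio as gr

def user(message, history):
    # Record the user turn and leave an empty assistant slot to stream into.
    return "", history + [[message, ""]]

def chat(curr_system_message, history, model):
    # Stand-in for the real generator above.
    history[-1][1] = f"(reply from {model})"
    yield history

with gr.Blocks(title="Chatbot") as demo:
    model = gr.Dropdown(
        label="Select Model",
        value="TheBloke/vicuna-7B-1.1-HF",
        choices=[
            "stabilityai/stablelm-tuned-alpha-3b",
            "TheBloke/vicuna-7B-1.1-HF",
        ],
    )
    system_msg = gr.Textbox(label="System message")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    # Extra inputs are passed positionally, so appending `model` here is all
    # it takes to match the new chat(curr_system_message, history, model).
    msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat, inputs=[system_msg, chatbot, model], outputs=[chatbot], queue=True
    )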