[SD] Integrate vicuna in the web (#1410)

Gaurav Shukla
2023-05-11 00:00:22 +05:30
committed by GitHub
parent 517c670f82
commit fcb059aa38
2 changed files with 104 additions and 26 deletions


@@ -15,6 +15,11 @@ from transformers import (
     StoppingCriteriaList,
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
+from apps.language_models.scripts.sharded_vicuna_fp32 import (
+    tokenizer,
+    SAMPLE_INPUT_LEN,
+    get_sharded_model,
+)
 
 start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
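
Review note: if importing sharded_vicuna_fp32 does heavy work at import time (for instance, building or compiling the sharded model; not verified here), every web session pays that cost even when only StableLM is used. A minimal sketch of a lazy-import alternative, assuming only the module path and the three names imported above:

import importlib

_vicuna = None

def _vicuna_helpers():
    # Import the vicuna script the first time it is needed, so sessions
    # that never select vicuna skip the import cost entirely.
    # Assumption: module path and exported names match the import above.
    global _vicuna
    if _vicuna is None:
        _vicuna = importlib.import_module(
            "apps.language_models.scripts.sharded_vicuna_fp32"
        )
    return _vicuna.tokenizer, _vicuna.SAMPLE_INPUT_LEN, _vicuna.get_sharded_model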
@@ -34,9 +39,63 @@ attention_mask = torch.randint(3, (1, 256))
 sharkModel = 0
 sharded_model = 0
-def chat(curr_system_message, history):
+start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+past_key_values = None
+def chat(curr_system_message, history, model):
+    global sharded_model
+    global past_key_values
+    if "vicuna" in model:
+        curr_system_message = start_message_vicuna
+        if sharded_model == 0:
+            sharded_model = get_sharded_model()
+        messages = curr_system_message + "".join(
+            [
+                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
+                for item in history
+            ]
+        )
+        prompt = messages.strip()
+        input_ids = tokenizer(prompt).input_ids
+        new_sentence = ""
+        for _ in range(200):
+            original_input_ids = input_ids
+            input_id_len = len(input_ids)
+            pad_len = SAMPLE_INPUT_LEN - input_id_len
+            attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
+            input_ids = torch.tensor(input_ids)
+            input_ids = input_ids.reshape([1, input_id_len])
+            attention_mask = torch.nn.functional.pad(
+                torch.tensor(attention_mask),
+                (0, pad_len),
+                mode="constant",
+                value=0,
+            )
+            if _ == 0:
+                output = sharded_model.forward(input_ids, is_first=True)
+            else:
+                output = sharded_model.forward(
+                    input_ids, past_key_values=past_key_values, is_first=False
+                )
+            logits = output["logits"]
+            past_key_values = output["past_key_values"]
+            new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
+            if new_word == "</s>":
+                break
+            new_sentence += " " + new_word
+            history[-1][1] = new_sentence
+            yield history
+            next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
+            original_input_ids.append(next_token)
+            input_ids = [next_token]
+        print(new_sentence)
+        return history
     global sharkModel
     print("In chat")
     if sharkModel == 0:
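
The loop in this hunk is greedy decoding with a key/value cache: the first forward pass consumes the whole prompt (is_first=True), and every later pass feeds only the single token just picked by argmax, together with the cached past_key_values. A self-contained sketch of the same pattern against the standard Hugging Face API (gpt2 is an arbitrary stand-in; the sharded SHARK model above is driven through forward(..., is_first=...) instead):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any small causal LM demonstrates the pattern; gpt2 is an assumption here.
tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok("A chat between a user and an assistant.", return_tensors="pt").input_ids
past = None
for step in range(20):
    # First pass: full prompt. Later passes: one new token + cached state.
    out = lm(input_ids=input_ids, past_key_values=past, use_cache=True)
    past = out.past_key_values
    next_id = out.logits[:, -1, :].argmax(dim=-1)  # greedy pick
    if next_id.item() == tok.eos_token_id:         # the "</s>" check above
        break
    input_ids = next_id.unsqueeze(0)               # shape [1, 1]
    print(tok.decode(next_id), end="", flush=True)

One thing worth flagging in the hunk itself: attention_mask is built and padded to SAMPLE_INPUT_LEN but, in the lines shown here, is never passed to sharded_model.forward, so that padding work appears to be dead code.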
@@ -95,12 +154,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
         model = gr.Dropdown(
             label="Select Model",
-            value="stabilityai/stablelm-tuned-alpha-3b",
-            choices=["stabilityai/stablelm-tuned-alpha-3b"],
+            value="TheBloke/vicuna-7B-1.1-HF",
+            choices=[
+                "stabilityai/stablelm-tuned-alpha-3b",
+                "TheBloke/vicuna-7B-1.1-HF",
+            ],
         )
     device_value = None
     for d in available_devices:
-        if "cuda" in d:
+        if "vulkan" in d:
             device_value = d
             break
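
The selection loop now prefers the first vulkan device and leaves device_value as None when no vulkan device is present. A sketch of an ordered-preference variant (pick_device is a hypothetical helper, not part of this PR):

def pick_device(devices, preference=("vulkan", "cuda", "cpu")):
    # Return the first available device matching the most preferred
    # backend name; None if nothing matches.
    for backend in preference:
        for d in devices:
            if backend in d:
                return d
    return None

# device_value = pick_device(available_devices)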
@@ -130,12 +192,18 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
     submit_event = msg.submit(
         fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
     ).then(
-        fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True
+        fn=chat,
+        inputs=[system_msg, chatbot, model],
+        outputs=[chatbot],
+        queue=True,
     )
     submit_click_event = submit.click(
         fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
     ).then(
-        fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True
+        fn=chat,
+        inputs=[system_msg, chatbot, model],
+        outputs=[chatbot],
+        queue=True,
     )
     stop.click(
         fn=None,
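
Both the Enter-key and button paths now pass the model dropdown as a third input, which is how chat() learns whether to take the vicuna branch. A minimal, self-contained sketch of the same .submit(...).then(...) wiring against the Gradio 3 Blocks API (the handler bodies are echo stubs, not the PR's logic); queue=True on the chat step is what lets a generator stream partial histories:

import gradio as gr

def user(message, history):
    # Append the user turn and clear the textbox.
    return "", history + [[message, ""]]

def chat(system_msg, history, model):
    # Stub generator standing in for the real streaming handler.
    history[-1][1] = f"[{model}] echo: {history[-1][0]}"
    yield history

with gr.Blocks() as demo:
    model = gr.Dropdown(
        label="Select Model",
        value="TheBloke/vicuna-7B-1.1-HF",
        choices=[
            "stabilityai/stablelm-tuned-alpha-3b",
            "TheBloke/vicuna-7B-1.1-HF",
        ],
    )
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    system_msg = gr.Textbox(visible=False)
    msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat, inputs=[system_msg, chatbot, model], outputs=[chatbot], queue=True
    )

# demo.queue().launch()  # queue() is required for generator handlers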