Add codegen support in UI + lint

PhaneeshB
2023-07-11 19:22:34 +05:30
committed by Phaneesh Barwaria
parent 6e8dbf72bd
commit 3b8f7cc231
3 changed files with 35 additions and 10 deletions

View File

@@ -980,7 +980,7 @@ class UnshardedVicuna(SharkLLMBase):
         compilation_prompt = "".join(["0" for _ in range(17)])
         compilation_input_ids = self.tokenizer(
             compilation_prompt,
-            return_tensor="pt",
+            return_tensors="pt",
         ).input_ids
         compilation_input_ids = torch.tensor(
             compilation_input_ids
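For context, a minimal sketch of what the corrected keyword does, assuming a standard Hugging Face tokenizer (the checkpoint name below is illustrative, not necessarily the one this class loads): with return_tensors="pt" the call already yields a PyTorch tensor, whereas the misspelled return_tensor is not a recognized argument and the tokenizer falls back to returning plain Python lists.

# Hedged sketch, illustrative checkpoint rather than the repo's exact wiring.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3")
prompt = "".join(["0" for _ in range(17)])
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(type(input_ids), input_ids.shape)  # torch.Tensor of shape (1, sequence_length)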

View File

@@ -20,6 +20,7 @@ vicuna_model = 0
 past_key_values = None
 model_map = {
+    "codegen": "Salesforce/codegen25-7b-multi",
     "vicuna1p3": "lmsys/vicuna-7b-v1.3",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
     "StableLM": "stabilityai/stablelm-tuned-alpha-3b",
@@ -48,6 +49,7 @@ start_message = {
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
@@ -55,9 +57,16 @@ def create_prompt(model_name, history):
     system_message = start_message[model_name]
     if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
-        conversation = "".join(["".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) for item in history])
+        conversation = "".join(
+            [
+                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
+                for item in history
+            ]
+        )
     else:
-        conversation = "".join(["".join([item[0], item[1]]) for item in history])
+        conversation = "".join(
+            ["".join([item[0], item[1]]) for item in history]
+        )
     msg = system_message + conversation
     msg = msg.strip()
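To make the new branching concrete, here is a self-contained sketch that mirrors the hunk above (build_prompt is a hypothetical stand-in for create_prompt, which may do more than shown): chat models get <|USER|>/<|ASSISTANT|> role tags, while codegen gets its empty system message plus the raw concatenated history, so the model simply continues the code.

def build_prompt(model_name, history, start_message):
    # Hypothetical stand-in mirroring the diff; history is a list of
    # (user_text, assistant_text) pairs.
    system_message = start_message[model_name]  # "" for "codegen"
    if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
        conversation = "".join(
            "<|USER|>" + user + "<|ASSISTANT|>" + bot for user, bot in history
        )
    else:  # "codegen": no role tags, plain continuation
        conversation = "".join(user + bot for user, bot in history)
    return (system_message + conversation).strip()

print(build_prompt("codegen", [("def fib(n):", "")], {"codegen": ""}))  # -> "def fib(n):"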
@@ -72,7 +81,7 @@ def chat(curr_system_message, history, model, device, precision):
     model_name, model_path = list(map(str.strip, model.split("=>")))
     print(f"In chat for {model_name}")
-    if model_name in ["vicuna", "vicuna1p3"]:
+    if model_name in ["vicuna", "vicuna1p3", "codegen"]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
         )
@@ -88,11 +97,14 @@ def chat(curr_system_message, history, model, device, precision):
device = "vulkan"
else:
print("unrecognized device")
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
prompt = create_prompt(model_name, history)
print("prompt = ", prompt)
@@ -111,7 +123,9 @@ def chat(curr_system_message, history, model, device, precision):
     if sharkModel == 0:
         # max_new_tokens=512
-        shark_slm = SharkStableLM(model_name) # pass elements from UI as required
+        shark_slm = SharkStableLM(
+            model_name
+        ) # pass elements from UI as required
     # Construct the input message string for the model by concatenating the
     # current system message and conversation history
@@ -135,7 +149,9 @@ def chat(curr_system_message, history, model, device, precision):
 with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
-        model_choices = list(map(lambda x: f"{x[0]: <10} => {x[1]}", model_map))
+        model_choices = list(
+            map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
+        )
         model = gr.Dropdown(
             label="Select Model",
             value=model_choices[0],
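A short, hedged illustration of why .items() is needed in the hunk above: iterating a dict directly yields only its keys, so x[0] and x[1] would index individual characters of the model name, whereas .items() yields (name, path) pairs, which is what the label format string expects.

model_map = {"codegen": "Salesforce/codegen25-7b-multi"}

# Equivalent to the lambda/map in the diff, written as a comprehension:
labels = [f"{name: <10} => {path}" for name, path in model_map.items()]
print(labels)  # ['codegen    => Salesforce/codegen25-7b-multi']

# Without .items(), x would be the string "codegen", so x[0] == "c" and x[1] == "o".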
@@ -149,7 +165,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         print(supported_devices)
         device = gr.Dropdown(
             label="Device",
-            value=supported_devices[0] if enabled else "Only CUDA Supported for now",
+            value=supported_devices[0]
+            if enabled
+            else "Only CUDA Supported for now",
             choices=supported_devices,
             interactive=enabled,
         )
@@ -179,15 +197,21 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         submit = gr.Button("Submit", interactive=enabled)
         stop = gr.Button("Stop", interactive=enabled)
         clear = gr.Button("Clear", interactive=enabled)
-    system_msg = gr.Textbox(start_message, label="System Message", interactive=False, visible=False)
+    system_msg = gr.Textbox(
+        start_message, label="System Message", interactive=False, visible=False
+    )
-    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+    submit_event = msg.submit(
+        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
+    ).then(
         fn=chat,
         inputs=[system_msg, chatbot, model, device, precision],
         outputs=[chatbot],
         queue=True,
     )
-    submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+    submit_click_event = submit.click(
+        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
+    ).then(
         fn=chat,
         inputs=[system_msg, chatbot, model, device, precision],
         outputs=[chatbot],
View File

@@ -30,6 +30,7 @@ tk
 pywebview
 sentencepiece
 py-cpuinfo
+tiktoken # for codegen
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile