Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-10 06:17:55 -05:00
Add codegen support in UI + lint

commit 3b8f7cc231 (parent 6e8dbf72bd), committed by Phaneesh Barwaria

@@ -980,7 +980,7 @@ class UnshardedVicuna(SharkLLMBase):
         compilation_prompt = "".join(["0" for _ in range(17)])
         compilation_input_ids = self.tokenizer(
             compilation_prompt,
-            return_tensor="pt",
+            return_tensors="pt",
         ).input_ids
         compilation_input_ids = torch.tensor(
             compilation_input_ids
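
The functional fix in this hunk is the tokenizer keyword: `return_tensors` (plural) is the actual HuggingFace argument. The misspelled `return_tensor` is not recognized, so depending on the transformers version it is ignored or warned about and the call falls back to plain Python lists, which is why the surrounding code re-wraps the result in `torch.tensor(...)`. A minimal sketch of the corrected call (assumes `transformers` is installed; the checkpoint name is illustrative):

from transformers import AutoTokenizer

# Checkpoint name is illustrative; any HF tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3")

compilation_prompt = "".join(["0" for _ in range(17)])

# With the correct spelling, .input_ids is already a torch.Tensor of shape
# (1, seq_len), so no torch.tensor() re-wrap would be needed.
input_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
print(input_ids.shape)
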
@@ -20,6 +20,7 @@ vicuna_model = 0
 past_key_values = None
 
 model_map = {
+    "codegen": "Salesforce/codegen25-7b-multi",
     "vicuna1p3": "lmsys/vicuna-7b-v1.3",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
     "StableLM": "stabilityai/stablelm-tuned-alpha-3b",
@@ -48,6 +49,7 @@ start_message = {
         "The assistant gives helpful, detailed, and polite answers to the user's "
         "questions.\n"
     ),
+    "codegen": "",
 }
@@ -55,9 +57,16 @@ def create_prompt(model_name, history):
     system_message = start_message[model_name]
 
     if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
-        conversation = "".join(["".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) for item in history])
+        conversation = "".join(
+            [
+                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
+                for item in history
+            ]
+        )
     else:
-        conversation = "".join(["".join([item[0], item[1]]) for item in history])
+        conversation = "".join(
+            ["".join([item[0], item[1]]) for item in history]
+        )
 
     msg = system_message + conversation
     msg = msg.strip()
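
Behaviorally, the interesting part of the last two hunks is that "codegen" gets an empty system message and is not in the `<|USER|>`/`<|ASSISTANT|>` branch, so `create_prompt` reduces to plain concatenation of the history: the model sees raw code to continue rather than a chat transcript. A self-contained sketch (names copied from the diff, bodies condensed; the return is added here so the snippet prints something):

start_message = {"codegen": ""}

def create_prompt(model_name, history):
    system_message = start_message[model_name]
    if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
        conversation = "".join(
            [
                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
                for item in history
            ]
        )
    else:
        # codegen lands here: no role tags, no system prefix
        conversation = "".join(
            ["".join([item[0], item[1]]) for item in history]
        )
    return (system_message + conversation).strip()

print(create_prompt("codegen", [["def fib(n):", ""]]))  # -> def fib(n):
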
@@ -72,7 +81,7 @@ def chat(curr_system_message, history, model, device, precision):
     model_name, model_path = list(map(str.strip, model.split("=>")))
     print(f"In chat for {model_name}")
 
-    if model_name in ["vicuna", "vicuna1p3"]:
+    if model_name in ["vicuna", "vicuna1p3", "codegen"]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
         )
@@ -88,11 +97,14 @@ def chat(curr_system_message, history, model, device, precision):
             device = "vulkan"
         else:
             print("unrecognized device")
 
+        max_toks = 128 if model_name == "codegen" else 512
         vicuna_model = UnshardedVicuna(
             model_name,
             hf_model_path=model_path,
             device=device,
             precision=precision,
+            max_num_tokens=max_toks,
         )
     prompt = create_prompt(model_name, history)
     print("prompt = ", prompt)
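
Taken together with the -72,7 hunk above, codegen is routed through the existing UnshardedVicuna pipeline but with a smaller generation budget: 128 new tokens instead of 512. A runnable stand-in showing just that dispatch (the real constructor is apps.language_models.scripts.vicuna.UnshardedVicuna; the dict here is a stub for it):

def build_model(model_name, model_path, device, precision):
    if model_name in ["vicuna", "vicuna1p3", "codegen"]:
        # codegen completions are capped at 128 new tokens; chat models get 512
        max_toks = 128 if model_name == "codegen" else 512
        return dict(  # stand-in for UnshardedVicuna(...)
            model_name=model_name,
            hf_model_path=model_path,
            device=device,
            precision=precision,
            max_num_tokens=max_toks,
        )
    raise ValueError(f"unsupported model: {model_name}")

print(build_model("codegen", "Salesforce/codegen25-7b-multi", "vulkan", "fp16"))
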
@@ -111,7 +123,9 @@ def chat(curr_system_message, history, model, device, precision):
 
     if sharkModel == 0:
         # max_new_tokens=512
-        shark_slm = SharkStableLM(model_name)  # pass elements from UI as required
+        shark_slm = SharkStableLM(
+            model_name
+        )  # pass elements from UI as required
 
     # Construct the input message string for the model by concatenating the
     # current system message and conversation history
@@ -135,7 +149,9 @@ def chat(curr_system_message, history, model, device, precision):
 
 with gr.Blocks(title="Chatbot") as stablelm_chat:
     with gr.Row():
-        model_choices = list(map(lambda x: f"{x[0]: <10} => {x[1]}", model_map))
+        model_choices = list(
+            map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
+        )
         model = gr.Dropdown(
             label="Select Model",
             value=model_choices[0],
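
The `.items()` change is a genuine bug fix, not lint: iterating a dict yields only its keys, so `x[0]` and `x[1]` were the first two characters of each model name. With `.items()` each `x` is a `(name, path)` pair, which also produces the `"name => path"` format that `chat()` later splits on `"=>"`. A pure-Python demonstration:

model_map = {"codegen": "Salesforce/codegen25-7b-multi"}

buggy = list(map(lambda x: f"{x[0]: <10} => {x[1]}", model_map))
fixed = list(map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items()))

print(buggy)  # ['c          => o']  (first two letters of the key)
print(fixed)  # ['codegen    => Salesforce/codegen25-7b-multi']
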
@@ -149,7 +165,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         print(supported_devices)
         device = gr.Dropdown(
             label="Device",
-            value=supported_devices[0] if enabled else "Only CUDA Supported for now",
+            value=supported_devices[0]
+            if enabled
+            else "Only CUDA Supported for now",
             choices=supported_devices,
             interactive=enabled,
         )
@@ -179,15 +197,21 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         submit = gr.Button("Submit", interactive=enabled)
         stop = gr.Button("Stop", interactive=enabled)
         clear = gr.Button("Clear", interactive=enabled)
-    system_msg = gr.Textbox(start_message, label="System Message", interactive=False, visible=False)
+    system_msg = gr.Textbox(
+        start_message, label="System Message", interactive=False, visible=False
+    )
 
-    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+    submit_event = msg.submit(
+        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
+    ).then(
         fn=chat,
         inputs=[system_msg, chatbot, model, device, precision],
         outputs=[chatbot],
         queue=True,
     )
-    submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+    submit_click_event = submit.click(
+        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
+    ).then(
        fn=chat,
        inputs=[system_msg, chatbot, model, device, precision],
        outputs=[chatbot],
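
The submit wiring is only reflowed here, but the pattern is worth noting: the unqueued `user` callback echoes the message into the chatbot immediately, then `.then()` chains the queued `chat` call that produces the model reply. A minimal self-contained version of the same pattern (Gradio 3.x API; the callback bodies are placeholder assumptions, not the app's real ones):

import gradio as gr

def user(message, history):
    # echo the user turn immediately; the model reply slot starts empty
    return "", history + [[message, None]]

def chat(system_msg, history):
    history[-1][1] = "stub reply"  # the real app generates tokens here
    return history

with gr.Blocks(title="Chatbot") as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    system_msg = gr.Textbox("", visible=False)
    # fast unqueued echo, then the queued (potentially slow) model call
    msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True
    )

demo.launch()
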
@@ -30,6 +30,7 @@ tk
 pywebview
 sentencepiece
 py-cpuinfo
+tiktoken # for codegen
 
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile
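
tiktoken is pulled in because the CodeGen2.5 tokenizer is tiktoken-based and ships as custom code on the Hub, so loading it needs `trust_remote_code=True`. A hedged sketch (the tokenizer download happens on first use):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "Salesforce/codegen25-7b-multi", trust_remote_code=True
)
print(tok("def hello():").input_ids)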