[chatbot] Fix switching parameters in chatbot

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-08-16 21:08:44 +05:30
parent cb509343d9
commit 32eb78f0f9

View File

@@ -139,6 +139,9 @@ def get_default_config():
c.split_into_layers()
model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(
curr_system_message,
@@ -151,10 +154,22 @@ def chat(
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
if model_name in [
"vicuna",
"vicuna4",
@@ -167,18 +182,8 @@ def chat(
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
if new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
@@ -254,7 +259,8 @@ def chat(
SharkStableLM,
)
if sharkModel == 0:
if new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
# max_new_tokens=512
shark_slm = SharkStableLM(
model_name