mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -139,6 +139,9 @@ def get_default_config():
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
@@ -151,10 +154,22 @@ def chat(
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
|
||||
global vicuna_model
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
@@ -167,18 +182,8 @@ def chat(
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
@@ -254,7 +259,8 @@ def chat(
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
|
||||
Reference in New Issue
Block a user