mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add StreamingLLM support to studio2 chat (#2060)
* Streaming LLM * Update precision and add gpu support * (studio2) Separate weights generation for quantization support * Adapt prompt changes to studio flow * Remove outdated flag from llm compile flags. * (studio2) use turbine vmfbRunner * tweaks to prompts * Update CPU path and llm api test. * Change device in test to cpu. * Fixes to runner, device names, vmfb mgmt * Use small test without external weights.
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
from turbine_models.custom_models import stateless_llama
|
||||
from turbine_models.model_runner import vmfbRunner
|
||||
from turbine_models.gen_external_params.gen_external_params import gen_external_params
|
||||
import time
|
||||
from shark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
load_vmfb_using_mmap,
|
||||
)
|
||||
from apps.shark_studio.web.utils.file_utils import get_resource_path
|
||||
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.shark_studio.web.utils import get_resource_path
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
import gc
|
||||
@@ -16,6 +15,7 @@ llm_model_map = {
|
||||
"llama2_7b": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 4096,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
@@ -23,12 +23,34 @@ llm_model_map = {
|
||||
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 4096,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
},
|
||||
"TinyPixel/small-llama2": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "TinyPixel/small-llama2",
|
||||
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
|
||||
"stop_token": 2,
|
||||
"max_tokens": 1024,
|
||||
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
|
||||
},
|
||||
}
|
||||
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<s>", "</s>"
|
||||
|
||||
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
|
||||
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
|
||||
"""
|
||||
|
||||
|
||||
def append_user_prompt(history, input_prompt):
|
||||
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
|
||||
history += user_prompt
|
||||
return history
|
||||
|
||||
|
||||
class LanguageModel:
|
||||
def __init__(
|
||||
@@ -36,41 +58,85 @@ class LanguageModel:
|
||||
model_name,
|
||||
hf_auth_token=None,
|
||||
device=None,
|
||||
precision="fp32",
|
||||
quantization="int4",
|
||||
precision="",
|
||||
external_weights=None,
|
||||
use_system_prompt=True,
|
||||
streaming_llm=False,
|
||||
):
|
||||
print(llm_model_map[model_name])
|
||||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
|
||||
self.tempfile_name = get_resource_path("llm.torch.tempfile")
|
||||
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
|
||||
self.max_tokens = llm_model_map[model_name]["max_tokens"]
|
||||
self.iree_module_dict = None
|
||||
self.device = device.split("=>")[-1].strip()
|
||||
self.backend = self.device.split("://")[0]
|
||||
self.driver = self.backend
|
||||
if "cpu" in device:
|
||||
self.device = "cpu"
|
||||
self.backend = "llvm-cpu"
|
||||
self.driver = "local-task"
|
||||
|
||||
print(f"Selected {self.backend} as IREE target backend.")
|
||||
self.precision = "f32" if "cpu" in device else "f16"
|
||||
self.quantization = quantization
|
||||
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
|
||||
self.external_weight_file = None
|
||||
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
|
||||
self.file_spec = "_".join(
|
||||
[
|
||||
self.safe_name,
|
||||
self.precision,
|
||||
]
|
||||
)
|
||||
if self.quantization != "None":
|
||||
self.file_spec += "_" + self.quantization
|
||||
|
||||
if external_weights is not None:
|
||||
self.external_weight_file = get_resource_path(
|
||||
self.safe_name + "." + external_weights
|
||||
self.file_spec + "." + external_weights
|
||||
)
|
||||
|
||||
if streaming_llm:
|
||||
# Add streaming suffix to file spec after setting external weights filename.
|
||||
self.file_spec += "_streaming"
|
||||
self.streaming_llm = streaming_llm
|
||||
|
||||
self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile")
|
||||
# TODO: Tag vmfb with target triple of device instead of HAL backend
|
||||
self.vmfb_name = get_resource_path(
|
||||
f"{self.file_spec}_{self.backend}.vmfb.tempfile"
|
||||
)
|
||||
self.max_tokens = llm_model_map[model_name]["max_tokens"]
|
||||
self.iree_module_dict = None
|
||||
self.use_system_prompt = use_system_prompt
|
||||
self.global_iter = 0
|
||||
self.prev_token_len = 0
|
||||
self.first_input = True
|
||||
if self.external_weight_file is not None:
|
||||
if not os.path.exists(self.external_weight_file):
|
||||
print(
|
||||
f"External weight file {self.external_weight_file} does not exist. Generating..."
|
||||
)
|
||||
gen_external_params(
|
||||
hf_model_name=self.hf_model_name,
|
||||
quantization=self.quantization,
|
||||
weight_path=self.external_weight_file,
|
||||
hf_auth_token=hf_auth_token,
|
||||
precision=self.precision,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
|
||||
)
|
||||
if os.path.exists(self.vmfb_name) and (
|
||||
external_weights is None or os.path.exists(str(self.external_weight_file))
|
||||
):
|
||||
self.iree_module_dict = dict()
|
||||
(
|
||||
self.iree_module_dict["vmfb"],
|
||||
self.iree_module_dict["config"],
|
||||
self.iree_module_dict["temp_file_to_unlink"],
|
||||
) = load_vmfb_using_mmap(
|
||||
self.vmfb_name,
|
||||
device,
|
||||
device_idx=0,
|
||||
rt_flags=[],
|
||||
external_weight_file=self.external_weight_file,
|
||||
self.runner = vmfbRunner(
|
||||
device=self.driver,
|
||||
vmfb_path=self.vmfb_name,
|
||||
external_weight_path=self.external_weight_file,
|
||||
)
|
||||
if self.streaming_llm:
|
||||
self.model = self.runner.ctx.modules.streaming_state_update
|
||||
else:
|
||||
self.model = self.runner.ctx.modules.state_update
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_name,
|
||||
use_fast=False,
|
||||
@@ -82,7 +148,9 @@ class LanguageModel:
|
||||
hf_auth_token,
|
||||
compile_to="torch",
|
||||
external_weights=external_weights,
|
||||
external_weight_file=self.external_weight_file,
|
||||
precision=self.precision,
|
||||
quantization=self.quantization,
|
||||
streaming_llm=self.streaming_llm,
|
||||
)
|
||||
with open(self.tempfile_name, "w+") as f:
|
||||
f.write(self.torch_ir)
|
||||
@@ -99,19 +167,37 @@ class LanguageModel:
|
||||
|
||||
def compile(self) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
self.iree_module_dict = get_iree_compiled_module(
|
||||
# ONLY architecture/api-specific compile-time flags for each backend, if needed.
|
||||
# hf_model_id-specific global flags currently in model map.
|
||||
flags = []
|
||||
if "cpu" in self.backend:
|
||||
flags.extend(
|
||||
[
|
||||
"--iree-global-opt-enable-quantized-matmul-reassociation",
|
||||
]
|
||||
)
|
||||
elif self.backend == "vulkan":
|
||||
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
|
||||
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
self.tempfile_name,
|
||||
device=self.device,
|
||||
mmap=True,
|
||||
frontend="torch",
|
||||
external_weight_file=self.external_weight_file,
|
||||
model_config_path=None,
|
||||
extra_args=flags,
|
||||
write_to=self.vmfb_name,
|
||||
extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
|
||||
)
|
||||
# TODO: delete the temp file
|
||||
self.runner = vmfbRunner(
|
||||
device=self.driver,
|
||||
vmfb_path=self.vmfb_name,
|
||||
external_weight_path=self.external_weight_file,
|
||||
)
|
||||
if self.streaming_llm:
|
||||
self.model = self.runner.ctx.modules.streaming_state_update
|
||||
else:
|
||||
self.model = self.runner.ctx.modules.state_update
|
||||
|
||||
def sanitize_prompt(self, prompt):
|
||||
print(prompt)
|
||||
if isinstance(prompt, list):
|
||||
prompt = list(chain.from_iterable(prompt))
|
||||
prompt = " ".join([x for x in prompt if isinstance(x, str)])
|
||||
@@ -119,10 +205,12 @@ class LanguageModel:
|
||||
prompt = prompt.replace("\t", " ")
|
||||
prompt = prompt.replace("\r", " ")
|
||||
if self.use_system_prompt and self.global_iter == 0:
|
||||
prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
|
||||
prompt += " [/INST]"
|
||||
print(prompt)
|
||||
return prompt
|
||||
prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
|
||||
print(prompt)
|
||||
return prompt
|
||||
else:
|
||||
print(prompt)
|
||||
return f"{B_INST} {prompt} {E_INST}"
|
||||
|
||||
def chat(self, prompt):
|
||||
prompt = self.sanitize_prompt(prompt)
|
||||
@@ -134,26 +222,40 @@ class LanguageModel:
|
||||
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
st_time = time.time()
|
||||
if iter == 0:
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(
|
||||
self.iree_module_dict["config"].device, input_tensor
|
||||
)
|
||||
]
|
||||
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
|
||||
if self.streaming_llm:
|
||||
token_slice = max(self.prev_token_len - 1, 0)
|
||||
input_tensor = input_tensor[:, token_slice:]
|
||||
if self.streaming_llm and self.model["get_seq_step"]() > 600:
|
||||
print("Evicting cache space!")
|
||||
self.model["evict_kvcache_space"]()
|
||||
token_len = input_tensor.shape[-1]
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(self.runner.config.device, input_tensor)
|
||||
]
|
||||
if self.first_input or not self.streaming_llm:
|
||||
st_time = time.time()
|
||||
token = self.model["run_initialize"](*device_inputs)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
self.first_input = False
|
||||
else:
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(
|
||||
self.iree_module_dict["config"].device,
|
||||
token,
|
||||
)
|
||||
]
|
||||
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
|
||||
st_time = time.time()
|
||||
token = self.model["run_cached_initialize"](*device_inputs)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
|
||||
total_time = time.time() - st_time
|
||||
history.append(format_out(token))
|
||||
yield self.tokenizer.decode(history), total_time
|
||||
while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]:
|
||||
dec_time = time.time()
|
||||
if self.streaming_llm and self.model["get_seq_step"]() > 600:
|
||||
print("Evicting cache space!")
|
||||
self.model["evict_kvcache_space"]()
|
||||
token = self.model["run_forward"](token)
|
||||
history.append(format_out(token))
|
||||
total_time = time.time() - dec_time
|
||||
yield self.tokenizer.decode(history), total_time
|
||||
|
||||
self.prev_token_len = token_len + len(history)
|
||||
|
||||
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
|
||||
break
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
import logging
|
||||
import unittest
|
||||
import json
|
||||
from apps.shark_studio.api.llm import LanguageModel
|
||||
import gc
|
||||
|
||||
from apps.shark_studio.api.llm import LanguageModel
|
||||
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
|
||||
@@ -28,12 +30,13 @@ class SDAPITest(unittest.TestCase):
|
||||
print(i)
|
||||
|
||||
class LLMAPITest(unittest.TestCase):
|
||||
def testLLMSimple(self):
|
||||
def test01_LLMSmall(self):
|
||||
lm = LanguageModel(
|
||||
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||
"TinyPixel/small-llama2",
|
||||
hf_auth_token=None,
|
||||
device="cpu-task",
|
||||
external_weights="safetensors",
|
||||
device="cpu",
|
||||
precision="fp32",
|
||||
quantization="None",
|
||||
)
|
||||
count = 0
|
||||
for msg, _ in lm.chat("hi, what are you?"):
|
||||
@@ -42,9 +45,11 @@ class LLMAPITest(unittest.TestCase):
|
||||
count += 1
|
||||
continue
|
||||
assert (
|
||||
msg.strip(" ") == "Hello"
|
||||
), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
|
||||
msg.strip(" ") == "Turkish Turkish Turkish"
|
||||
), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
|
||||
break
|
||||
del lm
|
||||
gc.collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -11,19 +11,23 @@ from apps.shark_studio.api.llm import (
|
||||
)
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
B_SYS, E_SYS = "<s>", "</s>"
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def append_bot_prompt(history, input_prompt):
|
||||
user_prompt = f"{input_prompt} {E_SYS} {E_SYS}"
|
||||
history += user_prompt
|
||||
return history
|
||||
|
||||
|
||||
language_model = None
|
||||
|
||||
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_config():
|
||||
return False
|
||||
|
||||
@@ -39,9 +43,13 @@ def chat_fn(
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
cli=False,
|
||||
):
|
||||
global language_model
|
||||
if streaming_llm and prompt_prefix == "Clear":
|
||||
language_model = None
|
||||
return "Clearing history...", ""
|
||||
if language_model is None:
|
||||
history[-1][-1] = "Getting the model ready..."
|
||||
yield history, ""
|
||||
@@ -50,8 +58,8 @@ def chat_fn(
|
||||
device=device,
|
||||
precision=precision,
|
||||
external_weights="safetensors",
|
||||
external_weight_file="llama2_7b.safetensors",
|
||||
use_system_prompt=prompt_prefix,
|
||||
streaming_llm=streaming_llm,
|
||||
)
|
||||
history[-1][-1] = "Getting the model ready... Done"
|
||||
yield history, ""
|
||||
@@ -61,7 +69,7 @@ def chat_fn(
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, exec_time in language_model.chat(history):
|
||||
history[-1][-1] = text
|
||||
history[-1][-1] = f"{text}{E_SYS}"
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
@@ -73,101 +81,6 @@ def chat_fn(
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
return None
|
||||
print(f"Input keys : {InputData.keys()}")
|
||||
# print(f"model : {InputData['model']}")
|
||||
is_chat_completion_api = (
|
||||
"messages" in InputData.keys()
|
||||
) # else it is the legacy `completion` api
|
||||
# For Debugging input data from API
|
||||
# if is_chat_completion_api:
|
||||
# print(f"message -> role : {InputData['messages'][0]['role']}")
|
||||
# print(f"message -> content : {InputData['messages'][0]['content']}")
|
||||
# else:
|
||||
# print(f"prompt : {InputData['prompt']}")
|
||||
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
|
||||
global vicuna_model
|
||||
model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
|
||||
model_path = llm_model_map[model_name]
|
||||
device = "cpu-task"
|
||||
precision = "fp16"
|
||||
max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
|
||||
if max_toks is None:
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# make it working for codegen first
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
if is_chat_completion_api:
|
||||
# TODO: add funtionality for multiple messages
|
||||
prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
|
||||
else:
|
||||
prompt = InputData["prompt"]
|
||||
print("prompt = ", prompt)
|
||||
|
||||
res = vicuna_model.generate(prompt)
|
||||
res_op = None
|
||||
for op in res:
|
||||
res_op = op
|
||||
|
||||
if is_chat_completion_api:
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": res_op, # since we are yeilding the result
|
||||
},
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
else:
|
||||
choices = [
|
||||
{
|
||||
"text": res_op,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
|
||||
return {
|
||||
"id": end_time,
|
||||
"object": "chat.completion" if is_chat_completion_api else "text_completion",
|
||||
"created": int(end_time),
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
|
||||
def view_json_file(file_obj):
|
||||
content = ""
|
||||
with open(file_obj.name, "r") as fopen:
|
||||
@@ -198,7 +111,7 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int4",
|
||||
value="fp32",
|
||||
choices=[
|
||||
# "int4",
|
||||
# "int8",
|
||||
@@ -211,12 +124,18 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=False,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
streaming_llm = gr.Checkbox(
|
||||
label="Run in streaming mode (requires recompilation)",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
@@ -260,6 +179,7 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
@@ -281,6 +201,7 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
@@ -293,4 +214,19 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
clear.click(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
clear,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
streaming_llm,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
).then(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
12
apps/shark_studio/web/utils.py
Normal file
12
apps/shark_studio/web/utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
return ["cpu-task"]
|
||||
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
|
||||
return os.path.join(base_path, relative_path)
|
||||
Reference in New Issue
Block a user