Mirror of https://github.com/nod-ai/SHARK-Studio.git
Enable downloading vmfb/mlir for webui (#1807)
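This change exposes vmfb/MLIR downloading in the web UI: the chatbot tab gains a "Download vmfb from Shark tank if available" checkbox, chat() receives the checkbox value as a new download_vmfb argument and forwards it to the UnshardedVicuna constructor, and compile() now consults self.download_vmfb instead of taking the flag as a parameter. Below is a minimal sketch of the resulting call pattern; the argument list is abridged and the names are taken from the diff that follows, not verbatim repository code.

    # Sketch only -- arguments abridged; names as they appear in the diff below.
    vicuna_model = UnshardedVicuna(
        ...,                          # model / device / precision arguments as in the diff
        max_num_tokens=max_toks,
        download_vmfb=download_vmfb,  # value of the new webui checkbox
        load_mlir_from_shark_tank=True,
        extra_args_cmd=_extra_args,
    )
    # __init__ calls self.compile(); compile() checks self.download_vmfb before
    # fetching a prebuilt vmfb from gs://shark_tank.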
@@ -1272,7 +1272,7 @@ class UnshardedVicuna(VicunaBase):
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
self.tokenizer = self.get_tokenizer()
self.cache_vicunas = cache_vicunas
self.compile(download_vmfb)
self.compile()

def get_model_path(self, suffix="mlir"):
safe_device = self.device.split("-")[0]
@@ -1404,13 +1404,13 @@ class UnshardedVicuna(VicunaBase):

return "\n".join(new_lines)

def compile(self, download_vmfb=False):
def compile(self):
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
)
if not self.vicuna_vmfb_path.exists() and download_vmfb:
if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
self.vicuna_vmfb_path.absolute(),
@@ -1423,245 +1423,237 @@ class UnshardedVicuna(VicunaBase):
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
return

print(f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}")
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
else:
print(
f"[DEBUG] mlir not found at {self.vicuna_mlir_path.absolute()}"
)
mlir_generated = False
if self.load_mlir_from_shark_tank:
# download MLIR from shark tank
for suffix in ["mlirbc", "mlir"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
break
self.vicuna_mlir_path = self.get_model_path("mlir")
if not mlir_generated:
print(
f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
)
print(f"[DEBUG] vmfb not found")
mlir_generated = False
for suffix in ["mlirbc", "mlir"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
break

if not mlir_generated:
print("[DEBUG] generating mlir on device")
# Select a compilation prompt such that the resulting input_ids
# from the model's tokenizer has shape [1, 19]
if self.model_name == "codegen":
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
else:
compilation_prompt = "".join(["0" for _ in range(17)])
if not mlir_generated:
print(f"[DEBUG] mlir not found")
print("[DEBUG] generating mlir on device")
# Select a compilation prompt such that the resulting input_ids
# from the model's tokenizer has shape [1, 19]
if self.model_name == "codegen":
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
else:
compilation_prompt = "".join(["0" for _ in range(17)])

first_model_path = f"first_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(first_model_path).exists():
|
||||
print(f"loading {first_model_path}")
|
||||
with open(Path(first_model_path), "r") as f:
|
||||
first_module = f.read()
|
||||
first_model_path = f"first_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(first_model_path).exists():
|
||||
print(f"loading {first_model_path}")
|
||||
with open(Path(first_model_path), "r") as f:
|
||||
first_module = f.read()
|
||||
else:
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[
|
||||
0
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
first_module = self.write_in_dynamic_inputs0(
|
||||
str(first_module), dynamic_input_size=19
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path, "w+") as f:
|
||||
f.write(first_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
print(f"[DEBUG] Starting generation of second llama")
|
||||
second_model_path = f"second_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(second_model_path).exists():
|
||||
print(f"loading {second_model_path}")
|
||||
with open(Path(second_model_path), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros(
|
||||
[1, 1], dtype=torch.int64
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
dim1 = 40
|
||||
total_tuple = 80
|
||||
elif self.model_name == "llama2_70b":
|
||||
dim1 = 8
|
||||
total_tuple = 160
|
||||
else:
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
del model
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[
|
||||
0
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
first_module = self.write_in_dynamic_inputs0(
|
||||
str(first_module), dynamic_input_size=19
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path, "w+") as f:
|
||||
f.write(first_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
print(f"[DEBUG] Starting generation of second llama")
|
||||
second_model_path = f"second_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(second_model_path).exists():
|
||||
print(f"loading {second_model_path}")
|
||||
with open(Path(second_model_path), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros(
|
||||
[1, 1], dtype=torch.int64
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
dim1 = 40
|
||||
total_tuple = 80
|
||||
elif self.model_name == "llama2_70b":
|
||||
dim1 = 8
|
||||
total_tuple = 160
|
||||
else:
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del secondVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated second vicuna linalg mlir"
|
||||
)
|
||||
second_module = self.write_in_dynamic_inputs1(
|
||||
str(second_module)
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(second_model_path, "w+") as f:
|
||||
f.write(second_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
combined_module = self.combine_mlir_scripts(
first_module,
second_module,
self.vicuna_mlir_path,
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False] + [True] * total_tuple,
mlir_type="torchscript",
)
del first_module, second_module
del model
if self.precision in ["fp16", "int4"]:
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * total_tuple,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
second_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated second vicuna linalg mlir"
)
second_module = self.write_in_dynamic_inputs1(
str(second_module)
)
if self.cache_vicunas:
with open(second_model_path, "w+") as f:
f.write(second_module)
print("Finished writing IR after dynamic")

combined_module = self.combine_mlir_scripts(
first_module,
second_module,
self.vicuna_mlir_path,
)
del first_module, second_module

print(self.device)
if "rocm" in self.device:

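For readability, the artifact lookup order that the updated compile() implements in the hunk above can be condensed as the following sketch (names taken from the diff; the on-device MLIR generation and final device compilation steps are elided):

    # Condensed sketch, not the actual implementation.
    def compile(self):
        # 1. Optionally fetch a prebuilt vmfb from shark_tank, then reuse it if present.
        if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
            download_public_file(
                f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
                self.vicuna_vmfb_path.absolute(),
                single_file=True,
            )
        if self.vicuna_vmfb_path.exists():
            return  # a compiled artifact is already available

        # 2. Otherwise look for MLIR, preferring shark_tank copies (mlirbc, then mlir).
        mlir_generated = False
        for suffix in ["mlirbc", "mlir"]:
            self.vicuna_mlir_path = self.get_model_path(suffix)
            if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
                download_public_file(
                    f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
                    self.vicuna_mlir_path.absolute(),
                    single_file=True,
                )
            if self.vicuna_mlir_path.exists():
                mlir_generated = True
                break

        # 3. Fall back to generating the first/second modules on device with torch_mlir,
        #    combining them, and compiling the result for the selected device.
        if not mlir_generated:
            ...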
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
gr.Markdown(description)

with gr.Row():
with gr.Column(scale=0.5):
with gr.Column():
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",

@@ -160,14 +160,15 @@ def chat(
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key

global vicuna_model

model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
@@ -177,6 +178,8 @@ def chat(
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")

@@ -194,18 +197,6 @@ def chat(
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args

if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
elif "rocm" in device:
device = "rocm"

if new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
@@ -237,6 +228,8 @@ def chat(
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
)
# else:
@@ -360,6 +353,8 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
)

# TODO: add role dict for different models
@@ -430,15 +425,14 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
# print(supported_devices)
devices = gr.Dropdown(
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
# multiselect=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
@@ -450,7 +444,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
],
visible=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")

with gr.Row(visible=False):
with gr.Group():
@@ -485,7 +485,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
inputs=[
system_msg,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
queue=True,
)
@@ -493,7 +501,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, devices, precision, config_file],
inputs=[
system_msg,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
queue=True,
)