Tweaks to chatbot

This commit is contained in:
Ean Garvey
2024-05-30 18:30:40 -04:00
parent 222f387705
commit 18ecd61cce
2 changed files with 28 additions and 2 deletions

View File

@@ -3,8 +3,10 @@ from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import get_resource_path
from apps.shark_studio.web.utils.file_utils import get_resource_path, get_checkpoints_path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
@@ -65,6 +67,7 @@ class LanguageModel:
use_system_prompt=True,
streaming_llm=False,
):
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
@@ -165,6 +168,7 @@ class LanguageModel:
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
@@ -194,11 +198,23 @@ class LanguageModel:
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend([
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
])
if "gfx9" in self.triple:
flags.extend([
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true"
])
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
frontend="torch",
frontend="auto",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
@@ -328,6 +344,15 @@ class LanguageModel:
self.global_iter += 1
return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path
def llm_chat_api(InputData: dict):
from datetime import datetime as dt

View File

@@ -138,6 +138,7 @@ with gr.Blocks(title="Chat") as chat_element:
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",