Compare commits

...

2 Commits

Author SHA1 Message Date
Ean Garvey
ee0233e370 Fix formatting. 2023-11-13 20:01:28 -06:00
Daniel Garvey
a3deeec870 Dan shark studio (#1970)
* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-13 19:07:28 -06:00
7 changed files with 1055 additions and 4 deletions

View File

@@ -0,0 +1,91 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch
llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
}
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file
def chat(self, prompt):
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
history.append(token)
yield self.tokenizer.decode(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)

View File

@@ -0,0 +1,14 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)

View File

@@ -0,0 +1,428 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)

View File

View File

@@ -0,0 +1,517 @@
import gradio as gr
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
language_model = None
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg
def get_default_config():
return False
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
language_model = LanguageModel(
model, device=device, precision=precision
)
language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -17,6 +17,7 @@ pytest-forked
Pillow
parameterized
#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
@@ -49,4 +50,4 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

View File

@@ -16,7 +16,6 @@ import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
@@ -63,8 +62,7 @@ def get_iree_device_args(device, extra_args=[]):
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-flow-enable-quantized-matmul-reassociation"]
+ ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
@@ -321,6 +319,8 @@ def compile_module_to_flatbuffer(
input_type = "tosa"
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
elif frontend in ["torch", "pytorch"]:
input_type = "torch"
if compile_str:
flatbuffer_blob = ireec.compile_str(