remove shark 1.0 tests, add support for 2.0 llm

* add support for external weights

* add tests and edit deps
Daniel Garvey
2023-12-14 21:44:37 -06:00
committed by GitHub
parent f692a012e1
commit ebfcfec338
16 changed files with 377 additions and 576 deletions
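The external-weight support called out above is exposed through the reworked LLM API in apps/shark_studio/api/llm.py. As orientation, here is a minimal usage sketch mirroring the __main__ block in that file; it assumes a CPU build of the IREE runtime and network access to fetch the Hugging Face checkpoint.

    # Minimal usage sketch of the new LanguageModel API; the model name, device
    # string, and external_weights value are taken from the __main__ block in
    # apps/shark_studio/api/llm.py.
    from apps.shark_studio.api.llm import LanguageModel

    lm = LanguageModel(
        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        hf_auth_token=None,              # public checkpoint, no token required
        device="cpu-task",
        external_weights="safetensors",  # weights exported to a .safetensors file
    )

    # chat() is a generator; each yield is (decoded_text_so_far, step_time_seconds).
    for text, _ in lm.chat("hi, what are you?"):
        print(text)

On the first run the constructor exports torch IR, writes the weight file next to the cached llm.vmfb.tempfile, and compiles; later runs should mmap the cached .vmfb and attach the external weight file instead of recompiling.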


@@ -1,164 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.11"]
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
#cache: 'pip'
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
echo $PATH
pip list | grep -E "torch|iree"
# disabled due to a low-visibility memory issue with pytest on macos.
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest -k vulkan -s --ci
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
python build_tools/stable_diffusion_testing.py --device=vulkan

.github/workflows/test-studio.yml (new file)

@@ -0,0 +1,86 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Shark Studio
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py


@@ -1,9 +1,16 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer
llm_model_map = {
"llama2_7b": {
@@ -11,81 +18,161 @@ llm_model_map = {
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
self,
model_name,
hf_auth_token=None,
device=None,
precision="fp32",
external_weights=None,
use_system_prompt=True,
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
self.external_weight_file = None
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.safe_name + "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
self.iree_module_dict["config"],
self.iree_module_dict["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.vmfb_name,
device,
device_idx=0,
rt_flags=[],
external_weight_file=self.external_weight_file,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
external_weight_file=self.external_weight_file,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
write_to=self.vmfb_name,
)
# TODO: delete the temp file
def sanitize_prompt(self, prompt):
print(prompt)
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
prompt = prompt.replace("\n", " ")
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
prompt += " [/INST]"
print(prompt)
return prompt
def chat(self, prompt):
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
def format_out(results):
return torch.tensor(results.to_host()[0][0])
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
st_time = time.time()
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
history.append(token)
yield self.tokenizer.decode(history)
total_time = time.time() - st_time
history.append(format_out(token))
yield self.tokenizer.decode(history), total_time
if token == llm_model_map["llama2_7b"]["stop_token"]:
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
self.global_iter += 1
return result_output, total_time
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
for i in lm.chat("hi, what are you?"):
print(i)


@@ -8,7 +8,5 @@ def get_available_devices():
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
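get_resource_path() resolves against sys._MEIPASS when running from a PyInstaller bundle and against the module's own directory otherwise; the LLM API above uses it to place llm.vmfb.tempfile and the exported weight file. A standalone sketch of the same behaviour (the example path is illustrative):

    # Sketch: behaviour of get_resource_path(), copied from the hunk above.
    import os
    import sys

    def get_resource_path(relative_path):
        """Get absolute path to resource, works for dev and for PyInstaller"""
        base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)

    # In a dev checkout sys._MEIPASS is unset, so this resolves next to the
    # calling module, e.g. <repo>/apps/shark_studio/api/llm.vmfb.tempfile.
    print(get_resource_path("llm.vmfb.tempfile"))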


@@ -0,0 +1,34 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel
class LLMAPITest(unittest.TestCase):
def testLLMSimple(self):
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
count = 0
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == "Hello"
), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
break
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
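The CI workflow added above runs this file directly (python apps/shark_studio/tests/api_test.py); it can equally be driven through unittest's loader. A sketch using the dotted path from the diff:

    # Sketch: running the LLM API smoke test via unittest's loader, equivalent
    # to the direct invocation used in the test-studio workflow.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        "apps.shark_studio.tests.api_test.LLMAPITest.testLLMSimple"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)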


@@ -93,9 +93,7 @@ if __name__ == "__main__":
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
@@ -201,7 +199,7 @@ if __name__ == "__main__":
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they


@@ -1,4 +1,5 @@
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime as dt
@@ -21,104 +22,12 @@ def user(message, history):
language_model = None
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg
def get_default_config():
return False
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
# model_vmfb_key = ""
@@ -133,153 +42,37 @@ def chat_fn(
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
history[-1][-1] = "Getting the model ready..."
yield history, ""
language_model = LanguageModel(
model, device=device, precision=precision
model,
device=device,
precision=precision,
external_weights="safetensors",
external_weight_file="llama2_7b.safetensors",
use_system_prompt=prompt_prefix,
)
language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
history[-1][-1] = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
total_time = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
for text, exec_time in language_model.chat(history):
history[-1][-1] = text
if is_first:
prefill_time = exec_time
is_first = False
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
total_time += exec_time
token_count += 1
tokens_per_sec = token_count / total_time
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def llm_chat_api(InputData: dict):
@@ -297,17 +90,11 @@ def llm_chat_api(InputData: dict):
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
@@ -344,9 +131,7 @@ def llm_chat_api(InputData: dict):
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
@@ -379,9 +164,7 @@ def llm_chat_api(InputData: dict):
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
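For reference, llm_chat_api() above reads an OpenAI-style payload; the dict shapes below are inferred from the keys accessed in this hunk ("model", "messages" or "prompt", "max_tokens"), with illustrative values.

    # Sketch of request payloads accepted by llm_chat_api(); values are illustrative.
    chat_completion_request = {
        "model": "codegen",             # falls back to "codegen" when omitted
        "messages": [{"role": "user", "content": "hi, what are you?"}],
        "max_tokens": 128,
    }

    text_completion_request = {
        "model": "codegen",
        "prompt": "hi, what are you?",  # non-chat form uses "prompt" directly
        "max_tokens": 128,
    }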
@@ -457,9 +240,7 @@ with gr.Blocks(title="Chat") as chat_element:
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(


@@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
metrics[val] = line.split(" ")[-1].strip("\n")
metrics["Average step"] = metrics["Average step"].strip("ms/it")
metrics["Total image generation"] = metrics[
"Total image generation"
].strip("sec")
metrics["Total image generation"] = metrics["Total image generation"].strip("sec")
metrics["device"] = device
metrics["use_tune"] = use_tune
metrics["model_name"] = model_name
@@ -84,10 +82,14 @@ def test_loop(
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
inpaint_prompt_text = (
"--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
)
if os.name == "nt":
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
inpaint_prompt_text = (
'--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
)
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
@@ -174,9 +176,7 @@ def test_loop(
)
print(command)
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
os.makedirs("./test_images/golden/" + model_name, exist_ok=True)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
@@ -191,14 +191,10 @@ def test_loop(
)
test_file = glob(test_file_path)[0]
golden_path = (
"./test_images/golden/" + model_name + "/*.png"
)
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
compare_images(test_file, golden_file, upload=upload_bool)
except AssertionError as e:
print(e)
if exit_on_fail == True:
@@ -267,9 +263,7 @@ parser.add_argument(
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False)
if __name__ == "__main__":
args = parser.parse_args()


@@ -10,9 +10,7 @@ from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
"web/models/stable_diffusion/logos/nod-logo.png"
)
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png")
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
@@ -76,9 +74,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
[line["text"]] if type(line["text"]) is str else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
@@ -104,9 +100,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
choices=prompt_choices
)
return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices)
image_name.change(
fn=display_image,
@@ -123,12 +117,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
def save_prompt(dataset, image_name, prompts, prompt):
if (
dataset is None
or image_name is None
or prompts is None
or prompt is None
):
if dataset is None or image_name is None or prompts is None or prompt is None:
return
if prompts == "Add new":
@@ -137,9 +126,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
idx = prompt_data[image_name].index(prompts)
prompt_data[image_name][idx] = prompt
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
@@ -166,9 +153,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
return
prompt_data[image_name].remove(prompts)
prompt_path = (
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
)
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
@@ -231,9 +216,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"')
os.system(f'rm -rf "{dataset_path}"')
return gr.Dropdown.update(value=None)


@@ -8,8 +8,7 @@ from pathlib import Path
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
if path_to_transformers_hook.is_file():
pass
@@ -59,9 +58,7 @@ for line in fileinput.input(path_to_lazy_loader, inplace=True):
# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(
get_python_lib() + "/timm/layers/activations_jit.py"
)
path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py")
for line in fileinput.input(path_to_timm_activations, inplace=True):
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")


@@ -5,14 +5,25 @@ requires = [
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"
extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py"
exclude = '''
(
/(
| apps/stable_diffusion
| apps/language_models
| shark
| benchmarks
| tank
| build
| generated_imgs
| shark.venv
)/
| setup.py
)
'''


@@ -1,3 +1,3 @@
[pytest]
addopts = --verbose -s -p no:warnings
norecursedirs = inference tank/tflite examples benchmarks shark
norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio


@@ -1,9 +1,13 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-f https://openxla.github.io/iree/pip-release-links.html
--pre
setuptools
wheel
shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=python/turbine_models
# SHARK Runner
tqdm
@@ -17,11 +21,7 @@ pytest-forked
Pillow
parameterized
#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
@@ -49,9 +49,6 @@ pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
# For quantized GPTQ models
optimum
auto_gptq


@@ -44,14 +44,10 @@ def upscaler_test(verbose=False):
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[upscaler] response from server was : {res.status_code} {res.reason}"
)
print(f"[upscaler] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n")
def img2img_test(verbose=False):
@@ -96,14 +92,10 @@ def img2img_test(verbose=False):
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[img2img] response from server was : {res.status_code} {res.reason}"
)
print(f"[img2img] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n")
# NOTE Uncomment below to save the picture
@@ -133,13 +125,9 @@ def inpainting_test(verbose=False):
image_path = r"./rest_api_tests/dog.png"
img_file = open(image_path, "rb")
image = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
image = "data:image/png;base64," + base64.b64encode(img_file.read()).decode()
img_file = open(image_path, "rb")
mask = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
mask = "data:image/png;base64," + base64.b64encode(img_file.read()).decode()
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
@@ -166,14 +154,10 @@ def inpainting_test(verbose=False):
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[inpaint] response from server was : {res.status_code} {res.reason}"
)
print(f"[inpaint] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n")
def outpainting_test(verbose=False):
@@ -223,14 +207,10 @@ def outpainting_test(verbose=False):
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[outpaint] response from server was : {res.status_code} {res.reason}"
)
print(f"[outpaint] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n")
def txt2img_test(verbose=False):
@@ -262,14 +242,10 @@ def txt2img_test(verbose=False):
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[txt2img] response from server was : {res.status_code} {res.reason}"
)
print(f"[txt2img] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n")
def sd_models_test(verbose=False):
@@ -283,9 +259,7 @@ def sd_models_test(verbose=False):
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[sd_models] response from server was : {res.status_code} {res.reason}"
)
print(f"[sd_models] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
@@ -302,9 +276,7 @@ def sd_samplers_test(verbose=False):
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[sd_samplers] response from server was : {res.status_code} {res.reason}"
)
print(f"[sd_samplers] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
@@ -321,9 +293,7 @@ def options_test(verbose=False):
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[options] response from server was : {res.status_code} {res.reason}"
)
print(f"[options] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
@@ -340,9 +310,7 @@ def cmd_flags_test(verbose=False):
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[cmd-flags] response from server was : {res.status_code} {res.reason}"
)
print(f"[cmd-flags] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")


@@ -9,11 +9,6 @@ with open("README.md", "r", encoding="utf-8") as fh:
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]
setup(
name="nodai-SHARK",
@@ -39,7 +34,5 @@ setup(
install_requires=[
"numpy",
"PyYAML",
"torch-mlir",
]
+ backend_deps,
)


@@ -305,6 +305,7 @@ def compile_module_to_flatbuffer(
model_name="None",
debug=False,
compile_str=False,
write_to=None,
):
# Setup Compile arguments wrt to frontends.
input_type = "auto"
@@ -342,12 +343,24 @@ def compile_module_to_flatbuffer(
extra_args=args,
)
if write_to is not None:
with open(write_to, "wb") as f:
f.write(flatbuffer_blob)
return None
return flatbuffer_blob
def get_iree_module(
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
flatbuffer_blob,
device,
device_idx=None,
rt_flags: list = [],
external_weight_file=None,
):
if external_weight_file is not None:
index = ireert.ParameterIndex()
index.load(external_weight_file)
# Returns the compiled module and the configs.
for flag in rt_flags:
ireert.flags.parse_flag(flag)
@@ -369,7 +382,10 @@ def get_iree_module(
vm_module = ireert.VmModule.from_buffer(
config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
modules = []
if external_weight_file is not None:
modules.append(index.create_provider(scope="model"))
ctx = ireert.SystemContext(vm_modules=modules, config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config
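Condensed from the get_iree_module() change above (the mmap loading path further down follows the same pattern): external weights reach the runtime by loading a ParameterIndex from the safetensors file, wrapping it in an io_parameters module under the "model" scope, and registering it ahead of the program module. A sketch with assumed file names and the local-task driver (what a "cpu-task" device maps to):

    # Sketch: runtime-side wiring of an external weight file, condensed from
    # the changes in this file. File names and the driver are assumptions.
    import iree.runtime as ireert

    config = ireert.Config("local-task")

    index = ireert.ParameterIndex()
    index.load("llama2_7b.safetensors")          # external weight file

    vm_modules = [
        # Parameters must come before the program that references them; the
        # HAL module is supplied by the Config's defaults.
        ireert.create_io_parameters_module(
            config.vm_instance, index.create_provider(scope="model")
        ),
        ireert.VmModule.mmap(config.vm_instance, "llm.vmfb.tempfile"),
    ]

    ctx = ireert.SystemContext(vm_modules=vm_modules, config=config)
    program = getattr(ctx.modules, vm_modules[1].name)  # exposes run_initialize / run_forward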
@@ -380,6 +396,7 @@ def load_vmfb_using_mmap(
device: str,
device_idx: int = None,
rt_flags: list = [],
external_weight_file: str = None,
):
print(f"Loading module {flatbuffer_blob_or_path}...")
if "task" in device:
@@ -440,17 +457,28 @@ def load_vmfb_using_mmap(
mmaped_vmfb = ireert.VmModule.mmap(
config.vm_instance, flatbuffer_blob_or_path
)
vm_modules = []
if external_weight_file is not None:
index = ireert.ParameterIndex()
index.load(external_weight_file)
param_module = ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
)
vm_modules.append(param_module)
vm_modules.append(mmaped_vmfb)
vm_modules.append(
ireert.create_hal_module(config.vm_instance, config.device)
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
for flag in shark_args.additional_runtime_args:
ireert.flags.parse_flags(flag)
dl.log(f"ireert.SystemContext created")
if "vulkan" in device:
# Vulkan pipeline creation consumes significant amount of time.
print(
"\tCompiling Vulkan shaders. This may take a few minutes."
)
ctx.add_vm_module(mmaped_vmfb)
ctx = ireert.SystemContext(config=config, vm_modules=vm_modules)
dl.log(f"ireert.SystemContext created")
for flag in shark_args.additional_runtime_args:
ireert.flags.parse_flags(flag)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
@@ -475,6 +503,8 @@ def get_iree_compiled_module(
mmap: bool = False,
debug: bool = False,
compile_str: bool = False,
external_weight_file: str = None,
write_to: bool = None,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
@@ -485,6 +515,7 @@ def get_iree_compiled_module(
extra_args=extra_args,
debug=debug,
compile_str=compile_str,
write_to=write_to,
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -492,8 +523,14 @@ def get_iree_compiled_module(
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
if write_to is not None:
flatbuffer_blob = write_to
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx, rt_flags
flatbuffer_blob,
device,
device_idx,
rt_flags,
external_weight_file=external_weight_file,
)
else:
vmfb, config = get_iree_module(
@@ -501,6 +538,7 @@ def get_iree_compiled_module(
device,
device_idx=device_idx,
rt_flags=rt_flags,
external_weight_file=external_weight_file,
)
ret_params = {
"vmfb": vmfb,