Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-04-20 03:00:34 -04:00)

Compare commits: 20240419.1...20240525.1 (5 commits)
Commits in this range:

- fd07cae991
- 6cb86a843e
- 7db1612a5c
- 81d6e059ac
- e003d0abe8
.github/workflows/nightly.yml (vendored)
@@ -53,6 +53,7 @@ jobs:
         python process_skipfiles.py
         $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
         pip install -e .
         pip freeze -l
         pyinstaller .\apps\shark_studio\shark_studio.spec
         mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
+        signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
@@ -53,11 +53,11 @@ def initialize():
     clear_tmp_imgs()

     from apps.shark_studio.web.utils.file_utils import (
-        create_checkpoint_folders,
+        create_model_folders,
     )

     # Create custom models folders if they don't exist
-    create_checkpoint_folders()
+    create_model_folders()

     import gradio as gr
@@ -13,7 +13,7 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

 llm_model_map = {
-    "llama2_7b": {
+    "meta-llama/Llama-2-7b-chat-hf": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
         "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
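
The hunks that follow update every `llm_model_map` lookup to the new Hugging Face model ID key. A minimal sketch of the new keying scheme, assuming a `stop_token` entry like the one the real map carries (the value 2 is an illustrative stand-in for the Llama-2 end-of-sequence id, not taken from this diff):

# Illustrative subset of llm_model_map; the stop_token value is an assumption.
llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {
        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
        "stop_token": 2,
    },
}

def is_stop(token_id, model_key="meta-llama/Llama-2-7b-chat-hf"):
    # The generation loops below compare each decoded token against this value.
    return token_id == llm_model_map[model_key]["stop_token"]

print(is_stop(2), is_stop(42))  # True False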
@@ -258,7 +258,8 @@ class LanguageModel:

         history.append(format_out(token))
         while (
-            format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
+            format_out(token)
+            != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
             and len(history) < self.max_tokens
         ):
             dec_time = time.time()
@@ -272,7 +273,10 @@ class LanguageModel:

         self.prev_token_len = token_len + len(history)

-        if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
+        if (
+            format_out(token)
+            == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+        ):
             break

         for i in range(len(history)):
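
The loops touched in this file all share one skeleton: append the first token, then decode until the stop token or a length cap. A self-contained sketch of that control flow, with stand-in names (`next_token` and the integer ids replace the real model calls):

# Stand-alone skeleton of the decode loops above; next_token stands in for
# one real decode step, and the token ids are made up for illustration.
def decode_loop(first_token, stop_token, max_tokens, next_token):
    history = [first_token]
    token = first_token
    while token != stop_token and len(history) < max_tokens:
        token = next_token(token)  # one decode step
        history.append(token)
        if token == stop_token:
            break
    return history

print(decode_loop(5, 2, 8, lambda t: t - 1))  # [5, 4, 3, 2]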
@@ -306,7 +310,7 @@ class LanguageModel:
         self.first_input = False

         history.append(int(token))
-        while token != llm_model_map["llama2_7b"]["stop_token"]:
+        while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
             dec_time = time.time()
             result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
             history.append(int(token))
@@ -317,7 +321,7 @@ class LanguageModel:

         self.prev_token_len = token_len + len(history)

-        if token == llm_model_map["llama2_7b"]["stop_token"]:
+        if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
             break
         for i in range(len(history)):
             if type(history[i]) != int:
@@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
     else:
         print(f"prompt : {InputData['prompt']}")

-    model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
+    model_name = (
+        InputData["model"]
+        if "model" in InputData.keys()
+        else "meta-llama/Llama-2-7b-chat-hf"
+    )
     model_path = llm_model_map[model_name]
     device = InputData["device"] if "device" in InputData.keys() else "cpu"
     precision = "fp16"
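
The conditional above is the stock membership-test-with-default pattern; `dict.get` expresses the same defaulting more compactly. A sketch with a hypothetical request payload (the keys mirror the `InputData` fields used above):

# Hypothetical request payload for illustration only.
InputData = {"prompt": "hi"}

# Equivalent defaulting via dict.get, shown for comparison.
model_name = InputData.get("model", "meta-llama/Llama-2-7b-chat-hf")
device = InputData.get("device", "cpu")
print(model_name, device)  # meta-llama/Llama-2-7b-chat-hf cpu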
@@ -602,7 +602,9 @@ if __name__ == "__main__":

     global_obj._init()

-    sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
+    sd_json = view_json_file(
+        get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
+    )
     sd_kwargs = json.loads(sd_json)
     for arg in vars(cmd_opts):
         if arg in sd_kwargs:
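
The surrounding loop loads the default SD config as JSON and then overlays any matching command-line options on top of it. A minimal sketch of that load-then-override flow, assuming in-memory stand-ins for `view_json_file` and the parsed options (the real code reads default_sd_config.json from --config_dir):

import json
from argparse import Namespace

# Stand-in for the JSON returned by view_json_file.
sd_json = '{"height": 512, "width": 512}'
sd_kwargs = json.loads(sd_json)

cmd_opts = Namespace(height=768)  # stand-in for the parsed CLI options
for arg in vars(cmd_opts):
    if arg in sd_kwargs:
        sd_kwargs[arg] = getattr(cmd_opts, arg)
print(sd_kwargs)  # {'height': 768, 'width': 512}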
@@ -6,6 +6,7 @@ from io import BytesIO
 from pathlib import Path
 from tqdm import tqdm
 from omegaconf import OmegaConf
 from diffusers import StableDiffusionPipeline
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     download_from_original_stable_diffusion_ckpt,
@@ -87,6 +88,7 @@ def process_custom_pipe_weights(custom_weights):
         ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
         custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
         custom_weights_params = custom_weights
+
     return custom_weights_params, custom_weights_tgt


@@ -98,7 +100,7 @@ def get_civitai_checkpoint(url: str):
         base_filename = re.findall(
             '"([^"]*)"', response.headers["Content-Disposition"]
         )[0]
-        destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
+        destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename

         # we don't have this model downloaded yet
         if not destination_path.is_file():
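
The context lines above pull the download's filename out of the Content-Disposition response header with a quoted-string regex. A self-contained sketch of just that extraction (the header value here is a made-up example, not a real CivitAI response):

import re
from pathlib import Path

# Extract the quoted filename from a Content-Disposition header.
content_disposition = 'attachment; filename="model_v1.safetensors"'
base_filename = re.findall('"([^"]*)"', content_disposition)[0]
destination_path = Path.cwd() / "models" / base_filename
print(destination_path.name)  # model_v1.safetensors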
@@ -41,7 +41,7 @@ class SharkPipelineBase:
         self.device, self.device_id = clean_device_info(device)
         self.import_mlir = import_mlir
         self.iree_module_dict = {}
-        self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
+        self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
         if not os.path.exists(self.tmp_dir):
             os.mkdir(self.tmp_dir)
         self.tempfiles = {}
@@ -55,9 +55,7 @@ class SharkPipelineBase:
         # and your model map is populated with any IR - unique model IDs and their static params,
         # call this method to get the artifacts associated with your map.
         self.pipe_id = self.safe_name(pipe_id)
-        self.pipe_vmfb_path = Path(
-            os.path.join(get_checkpoints_path(".."), self.pipe_id)
-        )
+        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
         self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
         if submodel == "None":
             print("\n[LOG] Gathering any pre-compiled artifacts....")
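
The change above drops the ".." hop now that `get_checkpoints_path()` resolves against --model_dir. A sketch of the vmfb cache-path resolution, where `safe_name` and the checkpoints root are stand-ins for the pipeline helpers:

import os
from pathlib import Path

# Stand-in for the pipeline's safe_name helper.
def safe_name(name):
    return name.replace("/", "_").replace(" ", "_")

checkpoints_path = os.path.join(os.getcwd(), "models")  # stand-in for get_checkpoints_path()
pipe_id = safe_name("stabilityai/sd-turbo 512")
pipe_vmfb_path = Path(os.path.join(checkpoints_path, pipe_id))
# parents=True so the sketch runs standalone; the hunk uses parents=False
# because the base directory already exists by this point.
pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
print(pipe_vmfb_path)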
@@ -339,7 +339,7 @@ p.add_argument(
 p.add_argument(
     "--output_dir",
     type=str,
-    default=None,
+    default=os.path.join(os.getcwd(), "generated_imgs"),
     help="Directory path to save the output images and json.",
 )

@@ -613,12 +613,27 @@ p.add_argument(
 )

 p.add_argument(
-    "--ckpt_dir",
+    "--tmp_dir",
     type=str,
-    default="../models",
+    default=os.path.join(os.getcwd(), "shark_tmp"),
+    help="Path to tmp directory",
+)
+
+p.add_argument(
+    "--config_dir",
+    type=str,
+    default=os.path.join(os.getcwd(), "configs"),
+    help="Path to config directory",
+)
+
+p.add_argument(
+    "--model_dir",
+    type=str,
+    default=os.path.join(os.getcwd(), "models"),
     help="Path to directory where all .ckpts are stored in order to populate "
     "them in the web UI.",
 )

 # TODO: replace API flag when these can be run together
 p.add_argument(
     "--ui",
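
Taken together, this hunk replaces the relative-path --ckpt_dir flag with three absolute-default directory flags. A standalone sketch of the new flags, using a fresh argparse parser (the name `p` matches the file's convention; everything else is a self-contained demo):

import argparse
import os

p = argparse.ArgumentParser()
p.add_argument("--tmp_dir", type=str,
               default=os.path.join(os.getcwd(), "shark_tmp"),
               help="Path to tmp directory")
p.add_argument("--config_dir", type=str,
               default=os.path.join(os.getcwd(), "configs"),
               help="Path to config directory")
p.add_argument("--model_dir", type=str,
               default=os.path.join(os.getcwd(), "models"),
               help="Path to directory where all .ckpts are stored")
cmd_opts = p.parse_args([])  # empty argv, so all defaults apply
print(cmd_opts.tmp_dir, cmd_opts.config_dir, cmd_opts.model_dir)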
@@ -9,6 +9,7 @@ from apps.shark_studio.api.llm import (
     llm_model_map,
     LanguageModel,
 )
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 import apps.shark_studio.web.utils.globals as global_obj

 B_SYS, E_SYS = "<s>", "</s>"
@@ -64,6 +65,7 @@ def chat_fn(
             external_weights="safetensors",
             use_system_prompt=prompt_prefix,
             streaming_llm=streaming_llm,
+            hf_auth_token=cmd_opts.hf_auth_token,
         )
         history[-1][-1] = "Getting the model ready... Done"
         yield history, ""
@@ -231,9 +231,14 @@ def import_original(original_img, width, height):


 def base_model_changed(base_model_id):
-    new_choices = get_checkpoints(
-        os.path.join("checkpoints", os.path.basename(str(base_model_id)))
-    ) + get_checkpoints(model_type="checkpoints")
+    ckpt_path = Path(
+        os.path.join(
+            cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id))
+        )
+    )
+    ckpt_path.mkdir(parents=True, exist_ok=True)
+
+    new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints")

     return gr.Dropdown(
         value=new_choices[0] if len(new_choices) > 0 else "None",
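
The new version anchors the per-base-model checkpoint folder under --model_dir and creates it before listing its contents. A sketch of that logic, where `model_dir`, the base model id, and the glob (standing in for `get_checkpoints`) are illustrative:

import os
from pathlib import Path

model_dir = os.path.join(os.getcwd(), "models")  # stand-in for cmd_opts.model_dir
base_model_id = "stabilityai/stable-diffusion-2-1"
ckpt_path = Path(model_dir, "checkpoints", os.path.basename(base_model_id))
ckpt_path.mkdir(parents=True, exist_ok=True)
# glob stands in for get_checkpoints(ckpt_path)
new_choices = sorted(p.name for p in ckpt_path.glob("*.safetensors"))
print(new_choices if new_choices else "None")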
@@ -581,21 +586,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                             object_fit="fit",
                             preview=True,
                         )
-                    with gr.Row():
-                        std_output = gr.Textbox(
-                            value=f"{sd_model_info}\n"
-                            f"Images will be saved at "
-                            f"{get_generated_imgs_path()}",
-                            lines=2,
-                            elem_id="std_output",
-                            show_label=True,
-                            label="Log",
-                            show_copy_button=True,
-                        )
-                        sd_element.load(
-                            logger.read_sd_logs, None, std_output, every=1
-                        )
-                        sd_status = gr.Textbox(visible=False)
                     with gr.Row():
                         batch_count = gr.Slider(
                             1,
@@ -631,19 +621,18 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                     stop_batch = gr.Button("Stop")
                 with gr.Tab(label="Config", id=102) as sd_tab_config:
                     with gr.Column(elem_classes=["sd-right-panel"]):
                         with gr.Row(elem_classes=["fill"]):
-                            Path(get_configs_path()).mkdir(
-                                parents=True, exist_ok=True
-                            )
-                            default_config_file = os.path.join(
-                                get_configs_path(),
-                                "default_sd_config.json",
-                            )
-                            write_default_sd_config(default_config_file)
-                            sd_json = gr.JSON(
-                                elem_classes=["fill"],
-                                value=view_json_file(default_config_file),
-                            )
+                            Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
+                            default_config_file = os.path.join(
+                                get_configs_path(),
+                                "default_sd_config.json",
+                            )
+                            write_default_sd_config(default_config_file)
+                            sd_json = gr.JSON(
+                                label="SD Config",
+                                elem_classes=["fill"],
+                                value=view_json_file(default_config_file),
+                                render=False,
+                            )
                         with gr.Row():
                             with gr.Column(scale=3):
                                 load_sd_config = gr.FileExplorer(
@@ -706,11 +695,30 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                             inputs=[sd_json, sd_config_name],
                             outputs=[sd_config_name],
                         )
+                with gr.Row(elem_classes=["fill"]):
+                    sd_json.render()
+                    save_sd_config.click(
+                        fn=save_sd_cfg,
+                        inputs=[sd_json, sd_config_name],
+                        outputs=[sd_config_name],
+                    )
+                with gr.Tab(label="Log", id=103) as sd_tab_log:
+                    with gr.Row():
+                        std_output = gr.Textbox(
+                            value=f"{sd_model_info}\n"
+                            f"Images will be saved at "
+                            f"{get_generated_imgs_path()}",
+                            elem_id="std_output",
+                            show_label=True,
+                            label="Log",
+                            show_copy_button=True,
+                        )
+                        sd_element.load(
+                            logger.read_sd_logs, None, std_output, every=1
+                        )
+                        sd_status = gr.Textbox(visible=False)
                 with gr.Tab(label="Automation", id=104) as sd_tab_automation:
                     pass

             pull_kwargs = dict(
                 fn=pull_sd_configs,
@@ -66,33 +66,39 @@ def get_resource_path(path):


 def get_configs_path() -> Path:
-    configs = get_resource_path(os.path.join("..", "configs"))
+    configs = get_resource_path(cmd_opts.config_dir)
     if not os.path.exists(configs):
         os.mkdir(configs)
-    return Path(get_resource_path("../configs"))
+    return Path(configs)


 def get_generated_imgs_path() -> Path:
-    return Path(
-        cmd_opts.output_dir
-        if cmd_opts.output_dir
-        else get_resource_path("../generated_imgs")
-    )
+    outputs = get_resource_path(cmd_opts.output_dir)
+    if not os.path.exists(outputs):
+        os.mkdir(outputs)
+    return Path(outputs)


+def get_tmp_path() -> Path:
+    tmpdir = get_resource_path(cmd_opts.model_dir)
+    if not os.path.exists(tmpdir):
+        os.mkdir(tmpdir)
+    return Path(tmpdir)
+
+
 def get_generated_imgs_todays_subdir() -> str:
     return dt.now().strftime("%Y%m%d")


-def create_checkpoint_folders():
+def create_model_folders():
     dir = ["checkpoints", "vae", "lora", "vmfb"]
-    if not os.path.isdir(cmd_opts.ckpt_dir):
+    if not os.path.isdir(cmd_opts.model_dir):
         try:
-            os.makedirs(cmd_opts.ckpt_dir)
+            os.makedirs(cmd_opts.model_dir)
         except OSError:
             sys.exit(
-                f"Invalid --ckpt_dir argument, "
-                f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created."
+                f"Invalid --model_dir argument, "
+                f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
             )

     for root in dir:
@@ -100,7 +106,7 @@ def create_checkpoint_folders():


 def get_checkpoints_path(model_type=""):
-    return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type))
+    return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))


 def get_checkpoints(model_type="checkpoints"):
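
The path getters above now share one pattern: resolve a cmd_opts directory, create it on first use, and return a `Path`. A generic sketch of that pattern, with a plain string standing in for the `get_resource_path`/`cmd_opts` plumbing:

import os
from pathlib import Path

# Generic form of the resolve-create-return pattern in the getters above;
# the argument stands in for a cmd_opts path already resolved by
# get_resource_path.
def ensure_dir(path):
    if not os.path.exists(path):
        os.mkdir(path)  # single-level mkdir, matching the hunk
    return Path(path)

print(ensure_dir(os.path.join(os.getcwd(), "shark_tmp")))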
@@ -2,7 +2,9 @@ import os
 import shutil
 from time import time

-shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+
+shark_tmp = cmd_opts.tmp_dir  # os.path.join(os.getcwd(), "shark_tmp/")


 def clear_tmp_mlir():
@@ -1,12 +1,12 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--f https://openxla.github.io/iree/pip-release-links.html
+-f https://iree.dev/pip-release-links.html
 --pre

 setuptools
 wheel

-torch==2.3.0.dev20240305
-shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=core
+torch==2.3.0
+shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
 turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models

 # SHARK Runner
@@ -35,6 +35,7 @@ safetensors==0.3.1
 py-cpuinfo
 pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
 mpmath==1.3.0
+optimum

 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile
@@ -89,7 +89,4 @@ else {python -m venv .\shark.venv\}
 python -m pip install --upgrade pip
 pip install wheel
 pip install -r requirements.txt
-# remove this when windows DLL issues are fixed from LLVM changes
-pip install --force-reinstall https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_compiler-20240326.843-cp311-cp311-win_amd64.whl https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_runtime-20240326.843-cp311-cp311-win_amd64.whl
-
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"
@@ -6,6 +6,7 @@ import tempfile
 import os
 import hashlib

+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

 def create_hash(file_name):
     with open(file_name, "rb") as f:
@@ -120,7 +121,7 @@ class SharkImporter:
         is_dynamic=False,
         tracing_required=False,
         func_name="forward",
-        save_dir="./shark_tmp/",
+        save_dir=cmd_opts.tmp_dir,  # "./shark_tmp/"
         mlir_type="linalg",
     ):
         if self.frontend in ["torch", "pytorch"]:
@@ -806,7 +807,7 @@ def save_mlir(
         model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
     )
     if dir == "":
-        dir = os.path.join(".", "shark_tmp")
+        dir = cmd_opts.tmp_dir  # os.path.join(".", "shark_tmp")
     mlir_path = os.path.join(dir, model_name_mlir)
     print(f"saving {model_name_mlir} to {dir}")
     if not os.path.exists(dir):