Compare commits

..

8 Commits

Author SHA1 Message Date
Ean Garvey
8478106f73 Remove old nodlogo import. 2024-05-30 11:37:20 -05:00
Ean Garvey
a57db9b12c Merge branch 'main' into ean-igpu 2024-05-30 11:19:20 -05:00
Ean Garvey
d8bb4db800 formatting 2024-05-30 03:46:47 -05:00
Ean Garvey
9b916fe901 Merge branch 'ean-igpu' of https://www.github.com/nod-ai/SHARK into ean-igpu 2024-05-30 03:45:45 -05:00
Ean Garvey
e6b4885b0e custom pipe loading 2024-05-30 03:45:18 -05:00
Ean Garvey
10a2a43d70 Merge branch 'main' into ean-igpu 2024-05-30 01:09:04 -05:00
Ean Garvey
c7bd54de23 Small fixes to igpu, SDXL-turbo 2024-05-30 00:54:12 -05:00
Ean Garvey
1adfb3b24b Add igpu and custom triple support. 2024-05-29 13:59:13 -05:00
9 changed files with 111 additions and 232 deletions

View File

@@ -81,5 +81,4 @@ jobs:
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
# Disabled due to hang when exporting test llama2
# python apps/shark_studio/tests/api_test.py
python apps/shark_studio/tests/api_test.py

View File

@@ -3,13 +3,8 @@ from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.web.utils.file_utils import get_resource_path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
@@ -70,7 +65,6 @@ class LanguageModel:
use_system_prompt=True,
streaming_llm=False,
):
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
@@ -171,7 +165,6 @@ class LanguageModel:
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
@@ -201,27 +194,11 @@ class LanguageModel:
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
frontend="auto",
frontend="torch",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
@@ -352,17 +329,6 @@ class LanguageModel:
return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path
def llm_chat_api(InputData: dict):
from datetime import datetime as dt

View File

@@ -1,6 +1,5 @@
import gc
import torch
import gradio as gr
import time
import os
import json
@@ -105,7 +104,7 @@ class StableDiffusion:
self.base_model_id = base_model_id
self.custom_vae = custom_vae
self.is_sdxl = "xl" in self.base_model_id.lower()
self.is_custom = ".py" in self.base_model_id.lower()
self.is_custom = "custom" in self.base_model_id.lower()
if self.is_custom:
custom_module = load_script(
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
@@ -113,7 +112,8 @@ class StableDiffusion:
)
self.turbine_pipe = custom_module.StudioPipeline
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
if self.is_sdxl:
self.turbine_pipe = SharkSDXLPipeline
self.model_map = EMPTY_SDXL_MAP
else:
@@ -181,17 +181,12 @@ class StableDiffusion:
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
if custom_weights:
custom_weights = os.path.join(
@@ -258,6 +253,7 @@ class StableDiffusion:
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
@@ -276,7 +272,7 @@ class StableDiffusion:
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("\n[LOG] Submitting Request...")
print("[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
@@ -286,34 +282,9 @@ def shark_sd_fn_dict_input(
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
for i in range(1):
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
yield generated_imgs
def shark_sd_fn(
@@ -336,7 +307,7 @@ def shark_sd_fn(
device: str,
target_triple: str,
ondemand: bool,
compiled_pipeline: bool,
repeatable_seeds: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
@@ -399,7 +370,6 @@ def shark_sd_fn(
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
@@ -409,6 +379,7 @@ def shark_sd_fn(
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
@@ -441,35 +412,22 @@ def shark_sd_fn(
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
save_output_img(
out_imgs[current_batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return (generated_imgs, "")
def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj
global_obj.clear_cache()
gc.collect()
return generated_imgs, ""
def cancel_sd():

View File

@@ -0,0 +1,28 @@
{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}

View File

@@ -138,7 +138,6 @@ with gr.Blocks(title="Chat") as chat_element:
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",

View File

@@ -14,12 +14,11 @@ from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_configs,
write_default_sd_config,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
@@ -120,7 +119,7 @@ def pull_sd_configs(
device,
target_triple,
ondemand,
compiled_pipeline,
repeatable_seeds,
resample_type,
controlnets,
embeddings,
@@ -179,7 +178,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["compiled_pipeline"],
sd_json["repeatable_seeds"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
@@ -257,7 +256,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
elem_id="triple",
label="Architecture",
value="",
)
@@ -588,6 +587,21 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
@@ -606,15 +620,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
interactive=True,
visible=True,
)
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Use Repeatable Seeds for Batches",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
show_progress=False,
)
@@ -629,7 +645,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_configs(get_configs_path())
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
@@ -685,7 +701,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device,
target_triple,
ondemand,
compiled_pipeline,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,
@@ -702,22 +718,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
pull_kwargs = dict(
fn=pull_sd_configs,
@@ -741,7 +741,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device,
target_triple,
ondemand,
compiled_pipeline,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,

View File

@@ -1,95 +0,0 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
default_sd_configs = {
"default_sd_config.json": default_sd_config,
"sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}

View File

@@ -11,14 +11,39 @@ checkpoints_filetypes = (
"*.safetensors",
)
from apps.shark_studio.web.utils.default_configs import default_sd_configs
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
def write_default_sd_config(path):
with open(path, "w") as f:
f.write(default_sd_config)
def safe_name(name):

View File

@@ -8,9 +8,8 @@ wheel
torch==2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@deprecated-constraints#subdirectory=models
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
# SHARK Runner
tqdm