mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
7 Commits
gaurav/amd
...
debug
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4529fd0461 | ||
|
|
4c2bb4b7b4 | ||
|
|
d5013fd13e | ||
|
|
26f80ccbbb | ||
|
|
d2c3752dc7 | ||
|
|
4505c4549f | ||
|
|
793495c9c6 |
3
.github/workflows/test-studio.yml
vendored
3
.github/workflows/test-studio.yml
vendored
@@ -81,4 +81,5 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
pip install -r requirements.txt --no-cache-dir
|
||||
pip install -e .
|
||||
python apps/shark_studio/tests/api_test.py
|
||||
# Disabled due to hang when exporting test llama2
|
||||
# python apps/shark_studio/tests/api_test.py
|
||||
|
||||
@@ -3,8 +3,13 @@ from turbine_models.model_runner import vmfbRunner
|
||||
from turbine_models.gen_external_params.gen_external_params import gen_external_params
|
||||
import time
|
||||
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.shark_studio.web.utils.file_utils import get_resource_path
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.api.utils import parse_device
|
||||
from urllib.request import urlopen
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
import gc
|
||||
@@ -65,6 +70,7 @@ class LanguageModel:
|
||||
use_system_prompt=True,
|
||||
streaming_llm=False,
|
||||
):
|
||||
_, _, self.triple = parse_device(device)
|
||||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
|
||||
self.device = device.split("=>")[-1].strip()
|
||||
self.backend = self.device.split("://")[0]
|
||||
@@ -165,6 +171,7 @@ class LanguageModel:
|
||||
precision=self.precision,
|
||||
quantization=self.quantization,
|
||||
streaming_llm=self.streaming_llm,
|
||||
decomp_attn=True,
|
||||
)
|
||||
with open(self.tempfile_name, "w+") as f:
|
||||
f.write(self.torch_ir)
|
||||
@@ -194,11 +201,27 @@ class LanguageModel:
|
||||
)
|
||||
elif self.backend == "vulkan":
|
||||
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
|
||||
elif self.backend == "rocm":
|
||||
flags.extend(
|
||||
[
|
||||
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
|
||||
"--iree-llvmgpu-enable-prefetch=true",
|
||||
"--iree-opt-outer-dim-concat=true",
|
||||
"--iree-flow-enable-aggressive-fusion",
|
||||
]
|
||||
)
|
||||
if "gfx9" in self.triple:
|
||||
flags.extend(
|
||||
[
|
||||
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
|
||||
"--iree-codegen-llvmgpu-use-vector-distribution=true",
|
||||
]
|
||||
)
|
||||
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
self.tempfile_name,
|
||||
device=self.device,
|
||||
frontend="torch",
|
||||
frontend="auto",
|
||||
model_config_path=None,
|
||||
extra_args=flags,
|
||||
write_to=self.vmfb_name,
|
||||
@@ -329,6 +352,17 @@ class LanguageModel:
|
||||
return result_output, total_time
|
||||
|
||||
|
||||
def get_mfma_spec_path(target_chip, save_dir):
    """
    returns the local path of the MFMA attention/matmul transform spec,
    downloading it into save_dir only if it is not already cached there

    :param target_chip: target architecture string (currently unused; kept so
        existing callers do not need to change)
    :param save_dir: directory in which the spec file is stored
    :return: path to "attention_and_matmul_spec_mfma.mlir" inside save_dir
    :raises urllib.error.URLError: if the spec has to be downloaded and the
        request fails
    """
    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")

    # Check the cache BEFORE touching the network: a cached spec must keep
    # working offline and should not cost an HTTP round trip on every compile.
    if os.path.exists(spec_path):
        return spec_path

    # Context manager closes the HTTP response even if read/decode raises.
    with urlopen(url) as response:
        attn_spec = response.read().decode("utf-8")

    # The text was decoded as UTF-8, so write it back as UTF-8 rather than
    # with the platform default encoding.
    with open(spec_path, "w", encoding="utf-8") as f:
        f.write(attn_spec)
    return spec_path
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
from datetime import datetime as dt
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import gc
|
||||
import torch
|
||||
import gradio as gr
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import copy
|
||||
import importlib.util
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
@@ -56,6 +59,23 @@ EMPTY_FLAGS = {
|
||||
}
|
||||
|
||||
|
||||
def load_script(source, module_name):
    """
    reads file source and loads it as a module

    :param source: file to load
    :param module_name: name of module to register in sys.modules
    :return: loaded module
    :raises ImportError: if no import spec/loader can be created for source
        (for example, the file suffix is not a recognised Python module suffix)
    """

    spec = importlib.util.spec_from_file_location(module_name, source)
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a clear ImportError instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for module {module_name!r} from {source!r}"
        )

    module = importlib.util.module_from_spec(spec)
    # Register before executing so the script can resolve itself by name
    # while it runs.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered if the script
        # fails; the original exception is re-raised unchanged.
        sys.modules.pop(module_name, None)
        raise

    return module
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
# This class is responsible for executing image generation and creating
|
||||
# /managing a set of compiled modules to run Stable Diffusion. The init
|
||||
@@ -78,19 +98,27 @@ class StableDiffusion:
|
||||
num_loras: int = 0,
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
):
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
self.custom_vae = custom_vae
|
||||
self.is_sdxl = "xl" in self.base_model_id.lower()
|
||||
if self.is_sdxl:
|
||||
self.is_custom = ".py" in self.base_model_id.lower()
|
||||
if self.is_custom:
|
||||
custom_module = load_script(
|
||||
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
|
||||
"custom_pipeline",
|
||||
)
|
||||
self.turbine_pipe = custom_module.StudioPipeline
|
||||
self.model_map = custom_module.MODEL_MAP
|
||||
elif self.is_sdxl:
|
||||
self.turbine_pipe = SharkSDXLPipeline
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
else:
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
external_weights = "safetensors"
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
pipe_id_list = [
|
||||
@@ -122,9 +150,12 @@ class StableDiffusion:
|
||||
if triple in ["gfx940", "gfx942", "gfx90a"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "mfma"
|
||||
elif triple in ["gfx1100", "gfx1103"]:
|
||||
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "wmma"
|
||||
if triple in ["gfx1103", "gfx1150"]:
|
||||
# external weights have issues on igpu
|
||||
external_weights = None
|
||||
elif target_backend == "llvm-cpu":
|
||||
decomp_attn = False
|
||||
|
||||
@@ -150,12 +181,17 @@ class StableDiffusion:
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
|
||||
def prepare_pipe(
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
):
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
weights = copy.deepcopy(self.model_map)
|
||||
if not self.is_sdxl:
|
||||
compiled_pipeline = False
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
|
||||
if custom_weights:
|
||||
custom_weights = os.path.join(
|
||||
@@ -222,7 +258,6 @@ class StableDiffusion:
|
||||
guidance_scale,
|
||||
seed,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
@@ -241,7 +276,7 @@ class StableDiffusion:
|
||||
def shark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
print("[LOG] Submitting Request...")
|
||||
print("\n[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
if sd_kwargs[key] in [None, []]:
|
||||
@@ -251,9 +286,34 @@ def shark_sd_fn_dict_input(
|
||||
if key == "seed":
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
|
||||
for i in range(1):
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
yield generated_imgs
|
||||
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
|
||||
if not sd_kwargs["device"]:
|
||||
gr.Warning("No device specified. Please specify a device.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] not in [512, 1024]:
|
||||
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] != sd_kwargs["width"]:
|
||||
gr.Warning("Height and width must be the same. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
|
||||
if sd_kwargs["steps"] > 10:
|
||||
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
|
||||
return None, ""
|
||||
if sd_kwargs["guidance_scale"] > 3:
|
||||
gr.Warning(
|
||||
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
|
||||
)
|
||||
return None, ""
|
||||
if sd_kwargs["target_triple"] == "":
|
||||
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
|
||||
gr.Warning(
|
||||
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
|
||||
)
|
||||
return None, ""
|
||||
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
return generated_imgs
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
@@ -276,7 +336,7 @@ def shark_sd_fn(
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
compiled_pipeline: bool,
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
@@ -286,8 +346,6 @@ def shark_sd_fn(
|
||||
sd_init_image = [sd_init_image]
|
||||
is_img2img = True if sd_init_image[0] is not None else False
|
||||
|
||||
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
@@ -341,6 +399,7 @@ def shark_sd_fn(
|
||||
"adapters": adapters,
|
||||
"embeddings": embeddings,
|
||||
"is_img2img": is_img2img,
|
||||
"compiled_pipeline": compiled_pipeline,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
@@ -350,7 +409,6 @@ def shark_sd_fn(
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
@@ -383,22 +441,35 @@ def shark_sd_fn(
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
if not isinstance(out_imgs, list):
|
||||
out_imgs = [out_imgs]
|
||||
# total_time = time.time() - start_time
|
||||
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
|
||||
# print(f"\n[LOG] {text_output}")
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
# break
|
||||
# else:
|
||||
save_output_img(
|
||||
out_imgs[current_batch],
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
for batch in range(batch_size):
|
||||
save_output_img(
|
||||
out_imgs[batch],
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# TODO: make seed changes over batch counts more configurable.
|
||||
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
return generated_imgs, ""
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
def unload_sd():
    """
    announces the unload, calls clear_cache() on the shared web globals
    module, then runs a garbage collection pass
    """
    print("Unloading models.")

    # Imported at call time, as in the original, rather than at module import.
    import apps.shark_studio.web.utils.globals as studio_globals

    studio_globals.clear_cache()
    gc.collect()
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
|
||||
@@ -52,6 +52,13 @@ def get_available_devices():
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
@@ -64,20 +71,15 @@ def get_available_devices():
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
print(available_devices)
|
||||
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
igpu_id_candidates = [
|
||||
@@ -87,10 +89,9 @@ def get_available_devices():
|
||||
]
|
||||
for igpu_name in igpu_id_candidates:
|
||||
if igpu_name:
|
||||
print(f"Found iGPU: {igpu_name} for {device_str}")
|
||||
available_devices[idx] = device_str.replace(
|
||||
"AMD Radeon(TM) Graphics", f"AMD iGPU: {igpu_name}"
|
||||
)
|
||||
available_devices[idx] = device_str.replace(
|
||||
"AMD Radeon(TM) Graphics", igpu_name
|
||||
)
|
||||
break
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -24,47 +24,47 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
)
|
||||
schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
)
|
||||
schedulers["DPMSolverMultistepKarras"] = (
|
||||
DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
)
|
||||
schedulers["DPMSolverMultistepKarras++"] = (
|
||||
DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
algorithm_type="dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
)
|
||||
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
# )
|
||||
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
# )
|
||||
# schedulers["DPMSolverMultistepKarras"] = (
|
||||
# DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# use_karras_sigmas=True,
|
||||
# )
|
||||
# )
|
||||
# schedulers["DPMSolverMultistepKarras++"] = (
|
||||
# DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# algorithm_type="dpmsolver++",
|
||||
# use_karras_sigmas=True,
|
||||
# )
|
||||
# )
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
@@ -75,24 +75,24 @@ def get_schedulers(model_id):
|
||||
subfolder="scheduler",
|
||||
)
|
||||
)
|
||||
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2AncestralDiscrete"] = (
|
||||
KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
)
|
||||
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["KDPM2AncestralDiscrete"] = (
|
||||
# KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# )
|
||||
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
return schedulers
|
||||
|
||||
|
||||
@@ -102,17 +102,17 @@ def export_scheduler_model(model):
|
||||
|
||||
scheduler_model_map = {
|
||||
"PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
|
||||
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
|
||||
"LCM": export_scheduler_model("LCMScheduler"),
|
||||
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
"DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
"DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
"DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
"KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
# "LCM": export_scheduler_model("LCMScheduler"),
|
||||
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
# "DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
# "DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
{
|
||||
"prompt": [
|
||||
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
|
||||
],
|
||||
"negative_prompt": [
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
"sd_init_image": [null],
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"steps": 50,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 7.5,
|
||||
"seed": "-1",
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerDiscrete",
|
||||
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"precision": "fp16",
|
||||
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
|
||||
"ondemand": false,
|
||||
"repeatable_seeds": false,
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}
|
||||
@@ -76,8 +76,8 @@ def launch_webui(address):
|
||||
def webui():
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
nodicon_loc,
|
||||
nodlogo_loc,
|
||||
amdicon_loc,
|
||||
amdlogo_loc,
|
||||
)
|
||||
|
||||
launch_api = cmd_opts.api
|
||||
@@ -172,9 +172,9 @@ def webui():
|
||||
analytics_enabled=False,
|
||||
title="Shark Studio 2.0 Beta",
|
||||
) as studio_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
amd_logo = Image.open(amdlogo_loc)
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
value=amd_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="tab_bar_logo",
|
||||
@@ -209,7 +209,7 @@ def webui():
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=cmd_opts.server_port,
|
||||
favicon_path=nodicon_loc,
|
||||
favicon_path=amdicon_loc,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -138,6 +138,7 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
label="Run in streaming mode (requires recompilation)",
|
||||
value=True,
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
|
||||
@@ -367,7 +367,7 @@ footer {
|
||||
#tab_bar_logo .image-container {
|
||||
object-fit: scale-down;
|
||||
position: absolute !important;
|
||||
top: 14px;
|
||||
top: 10px;
|
||||
right: 0px;
|
||||
height: 36px;
|
||||
}
|
||||
}
|
||||
|
||||
BIN
apps/shark_studio/web/ui/logos/amd-icon.jpg
Normal file
BIN
apps/shark_studio/web/ui/logos/amd-icon.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.1 KiB |
BIN
apps/shark_studio/web/ui/logos/amd-logo.jpg
Normal file
BIN
apps/shark_studio/web/ui/logos/amd-logo.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.4 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 10 KiB |
@@ -10,7 +10,7 @@ from apps.shark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.shark_studio.web.ui.utils import nodlogo_loc
|
||||
from apps.shark_studio.web.ui.utils import amdlogo_loc
|
||||
from apps.shark_studio.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
@@ -60,7 +60,7 @@ def output_subdirs() -> list[str]:
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_element:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
amd_logo = Image.open(amdlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue:
|
||||
@@ -73,7 +73,7 @@ with gr.Blocks() as outputgallery_element:
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
value=amd_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
|
||||
@@ -14,11 +14,12 @@ from apps.shark_studio.web.utils.file_utils import (
|
||||
get_checkpoints_path,
|
||||
get_checkpoints,
|
||||
get_configs_path,
|
||||
write_default_sd_config,
|
||||
write_default_sd_configs,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
shark_sd_fn_dict_input,
|
||||
cancel_sd,
|
||||
unload_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
cnet_preview,
|
||||
@@ -32,7 +33,7 @@ from apps.shark_studio.modules.img_processing import (
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
nodlogo_loc,
|
||||
amdlogo_loc,
|
||||
none_to_str_none,
|
||||
str_none_to_none,
|
||||
)
|
||||
@@ -119,7 +120,7 @@ def pull_sd_configs(
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
controlnets,
|
||||
embeddings,
|
||||
@@ -178,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["compiled_pipeline"],
|
||||
sd_json["resample_type"],
|
||||
sd_json["controlnets"],
|
||||
sd_json["embeddings"],
|
||||
@@ -256,7 +257,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
allow_custom_value=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="triple",
|
||||
elem_id="target_triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
@@ -282,6 +283,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
choices=sd_default_models,
|
||||
allow_custom_value=True,
|
||||
) # base_model_id
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -586,21 +588,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
object_fit="fit",
|
||||
preview=True,
|
||||
)
|
||||
with gr.Row():
|
||||
std_output = gr.Textbox(
|
||||
value=f"{sd_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=True,
|
||||
label="Log",
|
||||
show_copy_button=True,
|
||||
)
|
||||
sd_element.load(
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
@@ -619,17 +606,15 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
cmd_opts.repeatable_seeds,
|
||||
label="Use Repeatable Seeds for Batches",
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Start")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
unload = gr.Button("Unload Models")
|
||||
unload.click(
|
||||
fn=unload_sd,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -644,7 +629,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
get_configs_path(),
|
||||
"default_sd_config.json",
|
||||
)
|
||||
write_default_sd_config(default_config_file)
|
||||
write_default_sd_configs(get_configs_path())
|
||||
sd_json = gr.JSON(
|
||||
elem_classes=["fill"],
|
||||
value=view_json_file(default_config_file),
|
||||
@@ -700,7 +685,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
@@ -717,6 +702,22 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
with gr.Tab(label="Log", id=103) as sd_tab_log:
|
||||
with gr.Row():
|
||||
std_output = gr.Textbox(
|
||||
value=f"{sd_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=True,
|
||||
label="Log",
|
||||
show_copy_button=True,
|
||||
)
|
||||
sd_element.load(
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
|
||||
pull_kwargs = dict(
|
||||
fn=pull_sd_configs,
|
||||
@@ -740,7 +741,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
|
||||
@@ -10,8 +10,8 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
amdlogo_loc = resource_path("logos/amd-logo.jpg")
|
||||
amdicon_loc = resource_path("logos/amd-icon.jpg")
|
||||
|
||||
|
||||
class HSLHue(IntEnum):
|
||||
|
||||
95
apps/shark_studio/web/utils/default_configs.py
Normal file
95
apps/shark_studio/web/utils/default_configs.py
Normal file
@@ -0,0 +1,95 @@
|
||||
default_sd_config = r"""{
|
||||
"prompt": [
|
||||
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
|
||||
],
|
||||
"negative_prompt": [
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
"sd_init_image": [null],
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"steps": 50,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 7.5,
|
||||
"seed": "-1",
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerDiscrete",
|
||||
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"precision": "fp16",
|
||||
"device": "",
|
||||
"target_triple": "",
|
||||
"ondemand": false,
|
||||
"compiled_pipeline": false,
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}"""
|
||||
|
||||
sdxl_30steps = r"""{
|
||||
"prompt": [
|
||||
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
|
||||
],
|
||||
"negative_prompt": [
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
"sd_init_image": [null],
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"steps": 30,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 7.5,
|
||||
"seed": "-1",
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerDiscrete",
|
||||
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"precision": "fp16",
|
||||
"device": "",
|
||||
"target_triple": "",
|
||||
"ondemand": false,
|
||||
"compiled_pipeline": true,
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}"""
|
||||
|
||||
sdxl_turbo = r"""{
|
||||
"prompt": [
|
||||
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
|
||||
],
|
||||
"negative_prompt": [
|
||||
""
|
||||
],
|
||||
"sd_init_image": [null],
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"steps": 2,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 0,
|
||||
"seed": "-1",
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerAncestralDiscrete",
|
||||
"base_model_id": "stabilityai/sdxl-turbo",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"precision": "fp16",
|
||||
"device": "",
|
||||
"target_triple": "",
|
||||
"ondemand": false,
|
||||
"compiled_pipeline": true,
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}"""
|
||||
|
||||
default_sd_configs = {
|
||||
"default_sd_config.json": default_sd_config,
|
||||
"sdxl-30steps.json": sdxl_30steps,
|
||||
"sdxl-turbo.json": sdxl_turbo,
|
||||
}
|
||||
@@ -11,39 +11,14 @@ checkpoints_filetypes = (
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
default_sd_config = r"""{
|
||||
"prompt": [
|
||||
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
|
||||
],
|
||||
"negative_prompt": [
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
"sd_init_image": [null],
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"steps": 50,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 7.5,
|
||||
"seed": "-1",
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": "EulerDiscrete",
|
||||
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
|
||||
"custom_weights": null,
|
||||
"custom_vae": null,
|
||||
"precision": "fp16",
|
||||
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
|
||||
"ondemand": false,
|
||||
"repeatable_seeds": false,
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}"""
|
||||
from apps.shark_studio.web.utils.default_configs import default_sd_configs
|
||||
|
||||
|
||||
def write_default_sd_config(path):
|
||||
with open(path, "w") as f:
|
||||
f.write(default_sd_config)
|
||||
def write_default_sd_configs(path):
    """
    writes every entry of default_sd_configs into the directory path,
    using each key as the file name and each value as the file contents;
    existing files with the same name are overwritten

    :param path: directory that receives the config files
    """
    for filename, contents in default_sd_configs.items():
        target = os.path.join(path, filename)
        with open(target, "w") as config_file:
            config_file.write(contents)
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
|
||||
@@ -10,7 +10,7 @@ from utils import get_datasets
|
||||
|
||||
shark_root = Path(__file__).parent.parent
|
||||
demo_css = shark_root.joinpath("web/demo.css").resolve()
|
||||
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png")
|
||||
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/amd-logo.jpg")
|
||||
|
||||
|
||||
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
|
||||
@@ -8,8 +8,9 @@ wheel
|
||||
|
||||
torch==2.3.0
|
||||
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@deprecated-constraints#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
|
||||
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
Reference in New Issue
Block a user