mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Enable UI / bugfixes / tweaks
This commit is contained in:
134
apps/shark_studio/api/controlnet.py
Normal file
134
apps/shark_studio/api/controlnet.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
|
||||
|
||||
|
||||
class control_adapter:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
):
|
||||
self.model = None
|
||||
|
||||
def export_control_adapter_model(model_keyword):
|
||||
return None
|
||||
|
||||
def export_xl_control_adapter_model(model_keyword):
|
||||
return None
|
||||
|
||||
|
||||
class preprocessors:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
):
|
||||
self.model = None
|
||||
|
||||
def export_controlnet_model(model_keyword):
|
||||
return None
|
||||
|
||||
|
||||
control_adapter_map = {
|
||||
"sd15": {
|
||||
"canny": {"initializer": control_adapter.export_control_adapter_model},
|
||||
"openpose": {
|
||||
"initializer": control_adapter.export_control_adapter_model
|
||||
},
|
||||
"scribble": {
|
||||
"initializer": control_adapter.export_control_adapter_model
|
||||
},
|
||||
"zoedepth": {
|
||||
"initializer": control_adapter.export_control_adapter_model
|
||||
},
|
||||
},
|
||||
"sdxl": {
|
||||
"canny": {
|
||||
"initializer": control_adapter.export_xl_control_adapter_model
|
||||
},
|
||||
},
|
||||
}
|
||||
preprocessor_model_map = {
|
||||
"canny": {"initializer": preprocessors.export_controlnet_model},
|
||||
"openpose": {"initializer": preprocessors.export_controlnet_model},
|
||||
"scribble": {"initializer": preprocessors.export_controlnet_model},
|
||||
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
|
||||
}
|
||||
|
||||
|
||||
class PreprocessorModel:
|
||||
def __init__(
|
||||
self,
|
||||
hf_model_id,
|
||||
device,
|
||||
):
|
||||
self.model = None
|
||||
|
||||
def compile(self, device):
|
||||
print("compile not implemented for preprocessor.")
|
||||
return
|
||||
|
||||
def run(self, inputs):
|
||||
print("run not implemented for preprocessor.")
|
||||
return
|
||||
|
||||
|
||||
def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
|
||||
if isinstance(input_image, PIL.Image.Image):
|
||||
img_dict = {
|
||||
"background": None,
|
||||
"layers": [None],
|
||||
"composite": input_image,
|
||||
}
|
||||
input_image = EditorValue(img_dict)
|
||||
images[index] = input_image
|
||||
if model:
|
||||
stencils[index] = model
|
||||
match model:
|
||||
case "canny":
|
||||
canny = CannyDetector()
|
||||
result = canny(
|
||||
np.array(input_image["composite"]),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(result)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "openpose":
|
||||
openpose = OpenposeDetector()
|
||||
result = openpose(np.array(input_image["composite"]))
|
||||
preprocessed_hints[index] = Image.fromarray(result[0])
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "zoedepth":
|
||||
zoedepth = ZoeDetector()
|
||||
result = zoedepth(np.array(input_image["composite"]))
|
||||
preprocessed_hints[index] = Image.fromarray(result)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "scribble":
|
||||
preprocessed_hints[index] = input_image["composite"]
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case _:
|
||||
preprocessed_hints[index] = None
|
||||
return (
|
||||
None,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
@@ -26,22 +26,21 @@ def imports():
|
||||
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
# from apps.shark_studio.modules import shared_init
|
||||
# shared_init.initialize()
|
||||
# startup_timer.record("initialize shared")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj._init()
|
||||
startup_timer.record("initialize globals")
|
||||
|
||||
from apps.shark_studio.modules import (
|
||||
processing,
|
||||
gradio_extensons,
|
||||
ui,
|
||||
img_processing,
|
||||
) # noqa: F401
|
||||
from apps.shark_studio.modules.schedulers import scheduler_model_map
|
||||
|
||||
startup_timer.record("other imports")
|
||||
|
||||
|
||||
def initialize():
|
||||
configure_sigint_handler()
|
||||
configure_opts_onchange()
|
||||
|
||||
# from apps.shark_studio.modules import modelloader
|
||||
# modelloader.cleanup_models()
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from turbine_models.custom_models.sd_inference import clip, unet, vae
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
from apps.shark_studio.api.utils import get_resource_path
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
from apps.shark_studio.web.utils.state import status_label
|
||||
from apps.shark_studio.modules.pipeline import SharkPipelineBase
|
||||
import iree.runtime as ireert
|
||||
import gc
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
sd_model_map = {
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
@@ -86,16 +90,15 @@ sd_model_map = {
|
||||
|
||||
|
||||
class StableDiffusion(SharkPipelineBase):
|
||||
|
||||
# This class is responsible for executing image generation and creating
|
||||
# /managing a set of compiled modules to run Stable Diffusion. The init
|
||||
# aims to be as general as possible, and the class will infer and compile
|
||||
# a list of necessary modules or a combined "pipeline module" for a
|
||||
# specified job based on the inference task.
|
||||
#
|
||||
#
|
||||
# custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
|
||||
# e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
|
||||
#
|
||||
#
|
||||
# embeddings: a dict of embedding checkpoints or model IDs to use when
|
||||
# initializing the compiled modules.
|
||||
|
||||
@@ -107,7 +110,6 @@ class StableDiffusion(SharkPipelineBase):
|
||||
precision: str = "fp16",
|
||||
device: str = None,
|
||||
custom_model_map: dict = {},
|
||||
custom_weights_map: dict = {},
|
||||
embeddings: dict = {},
|
||||
import_ir: bool = True,
|
||||
):
|
||||
@@ -118,12 +120,185 @@ class StableDiffusion(SharkPipelineBase):
|
||||
self.iree_module_dict = None
|
||||
self.get_compiled_map()
|
||||
|
||||
def prepare_pipeline(self, scheduler, custom_model_map):
|
||||
return None
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompt,
|
||||
):
|
||||
return result_output,
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints,
|
||||
):
|
||||
return None, None, None, None, None
|
||||
|
||||
|
||||
# NOTE: Each `hf_model_id` should have its own starting configuration.
|
||||
|
||||
# model_vmfb_key = ""
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
base_model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
lora_weights: str | list,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
preprocessed_hints: list,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
continue
|
||||
elif stencil is None and any(
|
||||
img is not None for img in [images[i], preprocessed_hints[i]]
|
||||
):
|
||||
images[i] = None
|
||||
preprocessed_hints[i] = None
|
||||
elif images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
else:
|
||||
image = None
|
||||
is_img2img = False
|
||||
if image:
|
||||
(
|
||||
image,
|
||||
_,
|
||||
_,
|
||||
) = resize_stencil(image, width, height)
|
||||
is_img2img = True
|
||||
print("Performing Stable Diffusion Pipeline setup...")
|
||||
|
||||
device_id = None
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
custom_model_map = {}
|
||||
if custom_weights != "None":
|
||||
custom_model_map["unet"] = {"custom_weights": custom_weights}
|
||||
if custom_vae != "None":
|
||||
custom_model_map["vae"] = {"custom_weights": custom_vae}
|
||||
if stencils:
|
||||
for i, stencil in enumerate(stencils):
|
||||
if "xl" not in base_model_id.lower():
|
||||
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
|
||||
"runwayml/stable-diffusion-v1-5"
|
||||
][stencil]
|
||||
else:
|
||||
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
|
||||
"stabilityai/stable-diffusion-xl-1.0"
|
||||
][stencil]
|
||||
|
||||
submit_pipe_kwargs = {
|
||||
"base_model_id": base_model_id,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"custom_model_map": custom_model_map,
|
||||
"import_ir": cmd_opts.import_mlir,
|
||||
"is_img2img": is_img2img,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
"scheduler": scheduler,
|
||||
"custom_model_map": custom_model_map,
|
||||
"embeddings": lora_weights,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"steps": steps,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"preprocessed_hints": preprocessed_hints,
|
||||
}
|
||||
|
||||
global sd_pipe
|
||||
global sd_pipe_kwargs
|
||||
|
||||
if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
|
||||
sd_pipe = None
|
||||
sd_pipe_kwargs = submit_pipe_kwargs
|
||||
gc.collect()
|
||||
|
||||
if sd_pipe is None:
|
||||
history[-1][-1] = "Getting the pipeline ready..."
|
||||
yield history, ""
|
||||
|
||||
# Initializes the pipeline and retrieves IR based on all
|
||||
# parameters that are static in the turbine output format,
|
||||
# which is currently MLIR in the torch dialect.
|
||||
|
||||
sd_pipe = SharkStableDiffusionPipeline(
|
||||
**submit_pipe_kwargs,
|
||||
)
|
||||
|
||||
sd_pipe.prepare_pipe(**submit_prep_kwargs)
|
||||
|
||||
for prompt, msg, exec_time in progress.tqdm(
|
||||
out_imgs=sd_pipe.generate_images(**submit_run_kwargs),
|
||||
desc="Generating Image...",
|
||||
):
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
seeds[current_batch],
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
), stencils, images
|
||||
|
||||
return generated_imgs, text_output, "", stencils, images
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
print("Inject call to cancel longer API calls.")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sd = StableDiffusion(
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
import glob
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
@@ -12,6 +13,19 @@ from random import (
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
checkpoints_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
@@ -75,32 +89,119 @@ def get_available_devices():
|
||||
return available_devices
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in cmd_opts.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in cmd_opts.device:
|
||||
cmd_opts.device = "cuda"
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in cmd_opts.device:
|
||||
cmd_opts.device = "cpu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if cmd_opts.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if cmd_opts.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
cmd_opts.output_dir
|
||||
if cmd_opts.output_dir
|
||||
cmd_opts.output_dir
|
||||
if cmd_opts.output_dir
|
||||
else get_resource_path("..\web\generated_imgs")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def get_checkpoints_path(model = ""):
|
||||
def create_checkpoint_folders():
|
||||
dir = ["vae", "lora"]
|
||||
if not cmd_opts.ckpt_dir:
|
||||
dir.insert(0, "models")
|
||||
else:
|
||||
if not os.path.isdir(cmd_opts.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, "
|
||||
f"{args.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_checkpoints_path(model=""):
|
||||
return get_resource_path(f"..\web\models\{model}")
|
||||
|
||||
|
||||
def get_checkpoints(path):
|
||||
files = []
|
||||
for file in
|
||||
def get_checkpoints(model="models"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_checkpoint_pathfile(checkpoint_name, model="models"):
|
||||
return os.path.join(get_checkpoints_path(model), checkpoint_name)
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
@@ -142,6 +243,30 @@ def get_device_mapping(driver, key_combination=3):
|
||||
return device_map
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rocm" in cmd_opts.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
if cmd_opts.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
if cmd_opts.data_tiling == False:
|
||||
iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
@@ -248,6 +373,7 @@ def parse_seed_input(seed_input: str | list | int):
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
@@ -258,6 +384,7 @@ def sanitize_seed(seed: int | str):
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
|
||||
66
apps/shark_studio/modules/checkpoint_proc.py
Normal file
66
apps/shark_studio/modules/checkpoint_proc.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = os.path.join("diffusers", path.stem)
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
if next(Path(path_to_diffusers).iterdir(), None):
|
||||
print("Checkpoint already loaded at : ", path_to_diffusers)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
"Diffusers' checkpoint will be identified here : ",
|
||||
path_to_diffusers,
|
||||
)
|
||||
from_safetensors = (
|
||||
True if custom_weights.lower().endswith(".safetensors") else False
|
||||
)
|
||||
# EMA weights usually yield higher quality images for inference but
|
||||
# non-EMA weights have been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
|
||||
# they want to go for EMA weight extraction or not.
|
||||
extract_ema = False
|
||||
print(
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
)
|
||||
pipe.save_pretrained(path_to_diffusers)
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def convert_original_vae(vae_checkpoint):
|
||||
vae_state_dict = {}
|
||||
for key in list(vae_checkpoint.keys()):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
|
||||
"main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=512)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
vae_state_dict, vae_config
|
||||
)
|
||||
return converted_vae_checkpoint
|
||||
@@ -1,5 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
import safetensors
|
||||
from safetensors.torch import load_file
|
||||
from apps.shark_studio.api.utils import get_checkpoint_pathfile
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
@@ -109,3 +114,58 @@ def update_lora_weight(model, use_lora, model_name):
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def get_lora_metadata(lora_filename):
|
||||
# get the metadata from the file
|
||||
filename = get_checkpoint_pathfile(lora_filename, "lora")
|
||||
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
# guard clause for if there isn't any metadata
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# metadata is a dictionary of strings, the values of the keys we're
|
||||
# interested in are actually json, and need to be loaded as such
|
||||
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
|
||||
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
|
||||
tag_dirs = [dir for dir in tag_frequencies.keys()]
|
||||
|
||||
# gather the tag frequency information for all the datasets trained
|
||||
all_frequencies = {}
|
||||
for dataset in tag_dirs:
|
||||
frequencies = sorted(
|
||||
[entry for entry in tag_frequencies[dataset].items()],
|
||||
reverse=True,
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
|
||||
# get a figure for the total number of images processed for this dataset
|
||||
# either then number actually listed or in its dataset_dir entry or
|
||||
# the highest frequency's number if that doesn't exist
|
||||
img_count = dataset_dirs.get(dir, {}).get(
|
||||
"img_count", frequencies[0][1]
|
||||
)
|
||||
|
||||
# add the dataset frequencies to the overall frequencies replacing the
|
||||
# frequency counts on the tags with a percentage/ratio
|
||||
all_frequencies.update(
|
||||
[(entry[0], entry[1] / img_count) for entry in frequencies]
|
||||
)
|
||||
|
||||
trained_model_id = " ".join(
|
||||
[
|
||||
metadata.get("ss_sd_model_hash", ""),
|
||||
metadata.get("ss_sd_model_name", ""),
|
||||
metadata.get("ss_base_model_version", ""),
|
||||
]
|
||||
).strip()
|
||||
|
||||
# return the topmost <count> of all frequencies in all datasets
|
||||
return {
|
||||
"model": trained_model_id,
|
||||
"frequencies": sorted(
|
||||
all_frequencies.items(), reverse=True, key=lambda x: x[1]
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from
|
||||
import os
|
||||
import sys
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
@@ -10,43 +14,45 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15])
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = Path(os.path.basename(args.ckpt_loc)).stem
|
||||
img_model = cmd_opts.hf_model_id
|
||||
if cmd_opts.ckpt_loc:
|
||||
img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
|
||||
|
||||
img_vae = None
|
||||
if args.custom_vae:
|
||||
img_vae = Path(os.path.basename(args.custom_vae)).stem
|
||||
if cmd_opts.custom_vae:
|
||||
img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
if cmd_opts.use_lora:
|
||||
img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
if cmd_opts.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
if cmd_opts.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if args.use_hiresfix:
|
||||
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
|
||||
if cmd_opts.use_hiresfix:
|
||||
png_size_text = (
|
||||
f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
)
|
||||
else:
|
||||
png_size_text = f"{args.width}x{args.height}"
|
||||
png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}"
|
||||
f"\nNegative prompt: {args.negative_prompts[0]}"
|
||||
f"\nSteps: {args.steps},"
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"{cmd_opts.prompts[0]}"
|
||||
f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
|
||||
f"\nSteps: {cmd_opts.steps},"
|
||||
f"Sampler: {cmd_opts.scheduler}, "
|
||||
f"CFG scale: {cmd_opts.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
@@ -56,9 +62,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not "
|
||||
f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
@@ -68,18 +74,20 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"SCHEDULER": cmd_opts.scheduler,
|
||||
"PROMPT": cmd_opts.prompts[0],
|
||||
"NEG_PROMPT": cmd_opts.negative_prompts[0],
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height
|
||||
if not args.use_hiresfix
|
||||
else args.hiresfix_height,
|
||||
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"CFG_SCALE": cmd_opts.guidance_scale,
|
||||
"PRECISION": cmd_opts.precision,
|
||||
"STEPS": cmd_opts.steps,
|
||||
"HEIGHT": cmd_opts.height
|
||||
if not cmd_opts.use_hiresfix
|
||||
else cmd_opts.hiresfix_height,
|
||||
"WIDTH": cmd_opts.width
|
||||
if not cmd_opts.use_hiresfix
|
||||
else cmd_opts.hiresfix_width,
|
||||
"MAX_LENGTH": cmd_opts.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
"LORA": img_lora,
|
||||
@@ -95,37 +103,23 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if args.save_metadata_to_json:
|
||||
if cmd_opts.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
def get_generation_text_info(seeds, device):
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, "
|
||||
f"guidance_scale={args.guidance_scale}, "
|
||||
f"seed={seeds}"
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
if not args.use_hiresfix
|
||||
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
)
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
"Bilinear": Image.Resampling.BILINEAR,
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
"Hamming": Image.Resampling.HAMMING,
|
||||
"Box": Image.Resampling.BOX,
|
||||
}
|
||||
|
||||
return text_output
|
||||
resampler_list = resamplers.keys()
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size, but we need to ensure that
|
||||
@@ -133,7 +127,7 @@ def get_generation_text_info(seeds, device):
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image, width, height):
|
||||
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 128:
|
||||
@@ -166,6 +160,9 @@ def resize_stencil(image: Image.Image, width, height):
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
new_image = image.resize((n_width, n_height))
|
||||
if resampler_type in resamplers:
|
||||
resampler = resamplers[resampler_type]
|
||||
else:
|
||||
resampler = resamplers["Nearest Neighbor"]
|
||||
new_image = image.resize((n_width, n_height), resampler=resampler)
|
||||
return new_image, n_width, n_height
|
||||
|
||||
|
||||
71
apps/shark_studio/modules/pipeline.py
Normal file
71
apps/shark_studio/modules/pipeline.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
|
||||
|
||||
class SharkPipelineBase:
|
||||
# This class is a lightweight base for managing an
|
||||
# inference API class. It should provide methods for:
|
||||
# - compiling a set (model map) of torch IR modules
|
||||
# - preparing weights for an inference job
|
||||
# - loading weights for an inference job
|
||||
# - utilites like benchmarks, tests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_map: dict,
|
||||
device: str,
|
||||
import_mlir: bool = True,
|
||||
):
|
||||
self.model_map = model_map
|
||||
self.device = device
|
||||
self.import_mlir = import_mlir
|
||||
|
||||
def import_torch_ir(self, base_model_id):
|
||||
for submodel in self.model_map:
|
||||
hf_id = (
|
||||
submodel["custom_hf_id"]
|
||||
if submodel["custom_hf_id"]
|
||||
else base_model_id
|
||||
)
|
||||
torch_ir = submodel["initializer"](
|
||||
hf_id, **submodel["init_kwargs"], compile_to="torch"
|
||||
)
|
||||
submodel["tempfile_name"] = get_resource_path(
|
||||
f"{submodel}.torch.tempfile"
|
||||
)
|
||||
with open(submodel["tempfile_name"], "w+") as f:
|
||||
f.write(torch_ir)
|
||||
del torch_ir
|
||||
gc.collect()
|
||||
|
||||
def load_vmfb(self, submodel):
|
||||
if self.iree_module_dict[submodel]:
|
||||
print(
|
||||
f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
|
||||
)
|
||||
elif self.model_map[submodel]["tempfile_name"]:
|
||||
submodel["tempfile_name"]
|
||||
|
||||
return submodel["vmfb"]
|
||||
|
||||
def merge_custom_map(self, custom_model_map):
|
||||
for submodel in custom_model_map:
|
||||
for key in submodel:
|
||||
self.model_map[submodel][key] = key
|
||||
print(self.model_map)
|
||||
|
||||
def get_compiled_map(self, device) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
for submodel in self.model_map:
|
||||
if not self.iree_module_dict[submodel][vmfb]:
|
||||
self.iree_module_dict[submodel] = get_iree_compiled_module(
|
||||
submodel.tempfile_name,
|
||||
device=self.device,
|
||||
frontend="torch",
|
||||
)
|
||||
# TODO: delete the temp file
|
||||
|
||||
def run(self, submodel, inputs):
|
||||
return
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("-", "_")
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.stable_diffusion.src.utils.resamplers import resampler_list
|
||||
from apps.shark_studio.modules.img_processing import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
@@ -36,7 +36,7 @@ p.add_argument(
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smokes coming off the tires, front "
|
||||
"mountains at high speeds with smoke coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
@@ -306,21 +306,6 @@ p.add_argument(
|
||||
"downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Attempts to load the model from a precompiled flat-buffer "
|
||||
"and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Saves the compiled flat-buffer to the local directory.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
@@ -446,7 +431,7 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ondemand",
|
||||
"--lowvram",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Load and unload models for low VRAM.",
|
||||
@@ -469,10 +454,10 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--autogen",
|
||||
type=bool,
|
||||
default="False",
|
||||
help="Only used for a gradio workaround.",
|
||||
"--custom_model_map",
|
||||
type=str,
|
||||
default="",
|
||||
help="path to custom model map to import. This should be a .json file",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
@@ -612,6 +597,13 @@ p.add_argument(
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--webui",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="controls whether the webui is launched.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
@@ -764,8 +756,8 @@ p.add_argument(
|
||||
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
cmd_opts, unknown = p.parse_known_args()
|
||||
if cmd_opts.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
os.getcwd(), args.hf_model_id.replace("/", "_")
|
||||
os.getcwd(), cmd_opts.hf_model_id.replace("/", "_")
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from apps.shark_studio. import sd_samplers, postprocessing, errors, restart
|
||||
from apps.shark_studio.modules.img_processing import sampler_list
|
||||
from sdapi_v1 import shark_sd_api
|
||||
from api.llm import chat_api
|
||||
|
||||
@@ -26,15 +26,21 @@ def decode_base64_to_image(encoding):
|
||||
raise HTTPException(status_code=500, detail="Requests not allowed")
|
||||
|
||||
if opts.api_forbid_local_requests and not verify_url(encoding):
|
||||
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Request to local resource not allowed"
|
||||
)
|
||||
|
||||
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||
headers = (
|
||||
{"user-agent": opts.api_useragent} if opts.api_useragent else {}
|
||||
)
|
||||
response = requests.get(encoding, timeout=30, headers=headers)
|
||||
try:
|
||||
image = Image.open(BytesIO(response.content))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image url"
|
||||
) from e
|
||||
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
@@ -42,32 +48,54 @@ def decode_base64_to_image(encoding):
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid encoded image"
|
||||
) from e
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
if opts.samples_format.lower() == 'png':
|
||||
if opts.samples_format.lower() == "png":
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||
image.save(
|
||||
output_bytes,
|
||||
format="PNG",
|
||||
pnginfo=(metadata if use_metadata else None),
|
||||
quality=opts.jpeg_quality,
|
||||
)
|
||||
|
||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
parameters = image.info.get('parameters', None)
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||
})
|
||||
parameters = image.info.get("parameters", None)
|
||||
exif_bytes = piexif.dump(
|
||||
{
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(
|
||||
parameters or "", encoding="unicode"
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
image.save(
|
||||
output_bytes,
|
||||
format="JPEG",
|
||||
exif=exif_bytes,
|
||||
quality=opts.jpeg_quality,
|
||||
)
|
||||
else:
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
image.save(
|
||||
output_bytes,
|
||||
format="WEBP",
|
||||
exif=exif_bytes,
|
||||
quality=opts.jpeg_quality,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid image format")
|
||||
@@ -80,10 +108,11 @@ def encode_pil_to_base64(image):
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = False
|
||||
try:
|
||||
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
|
||||
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
rich_available = True
|
||||
except Exception:
|
||||
@@ -95,35 +124,49 @@ def api_middleware(app: FastAPI):
|
||||
res: Response = await call_next(req)
|
||||
duration = str(round(time.time() - ts, 4))
|
||||
res.headers["X-Process-Time"] = duration
|
||||
endpoint = req.scope.get('path', 'err')
|
||||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code=res.status_code,
|
||||
ver=req.scope.get('http_version', '0.0'),
|
||||
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot=req.scope.get('scheme', 'err'),
|
||||
method=req.scope.get('method', 'err'),
|
||||
endpoint=endpoint,
|
||||
duration=duration,
|
||||
))
|
||||
endpoint = req.scope.get("path", "err")
|
||||
if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"):
|
||||
print(
|
||||
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
|
||||
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code=res.status_code,
|
||||
ver=req.scope.get("http_version", "0.0"),
|
||||
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
|
||||
prot=req.scope.get("scheme", "err"),
|
||||
method=req.scope.get("method", "err"),
|
||||
endpoint=endpoint,
|
||||
duration=duration,
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def handle_exception(request: Request, e: Exception):
|
||||
err = {
|
||||
"error": type(e).__name__,
|
||||
"detail": vars(e).get('detail', ''),
|
||||
"body": vars(e).get('body', ''),
|
||||
"detail": vars(e).get("detail", ""),
|
||||
"body": vars(e).get("body", ""),
|
||||
"errors": str(e),
|
||||
}
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
if not isinstance(
|
||||
e, HTTPException
|
||||
): # do not print backtrace on known httpexceptions
|
||||
message = f"API error: {request.method}: {request.url} {err}"
|
||||
if rich_available:
|
||||
print(message)
|
||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||
console.print_exception(
|
||||
show_locals=True,
|
||||
max_frames=2,
|
||||
extra_lines=1,
|
||||
suppress=[anyio, starlette],
|
||||
word_wrap=False,
|
||||
width=min([console.width, 200]),
|
||||
)
|
||||
else:
|
||||
errors.report(message, exc_info=True)
|
||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||
return JSONResponse(
|
||||
status_code=vars(e).get("status_code", 500),
|
||||
content=jsonable_encoder(err),
|
||||
)
|
||||
|
||||
@app.middleware("http")
|
||||
async def exception_handling(request: Request, call_next):
|
||||
@@ -143,52 +186,48 @@ def api_middleware(app: FastAPI):
|
||||
|
||||
class ApiCompat:
|
||||
def __init__(self, queue_lock: Lock):
|
||||
|
||||
self.router = APIRouter()
|
||||
self.app = FastAPI()
|
||||
self.queue_lock = queue_lock
|
||||
api_middleware(self.app)
|
||||
self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"])
|
||||
self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"])
|
||||
#self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
|
||||
#self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||
#self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||
#self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||
#self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||
#self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
#self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
#self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||
#self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||
#self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
||||
#self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||
#self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
||||
#self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||
#self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||
#self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||
#self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
#self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
#self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
#self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
#self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||
#self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||
#self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||
#self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||
#self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
#self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
#self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
|
||||
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
|
||||
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||
# self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||
# self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||
# self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
# self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
# self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||
# self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||
# self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
||||
# self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||
# self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
||||
# self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||
# self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||
# self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||
# self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
# self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
# self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
# self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
# self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||
# self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||
# self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||
# self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||
# self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
# self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
|
||||
# chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
self.add_api_route(
|
||||
"/v1/chat/completions", chat_api, methods=["post"]
|
||||
)
|
||||
self.add_api_route("/v1/chat/completions", chat_api, methods=["post"])
|
||||
self.add_api_route("/v1/completions", chat_api, methods=["post"])
|
||||
self.add_api_route("/chat/completions", chat_api, methods=["post"])
|
||||
self.add_api_route("/completions", chat_api, methods=["post"])
|
||||
@@ -196,16 +235,26 @@ class ApiCompat:
|
||||
"/v1/engines/codegen/completions", chat_api, methods=["post"]
|
||||
)
|
||||
if studio.cmd_opts.api_server_stop:
|
||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"])
|
||||
self.add_api_route(
|
||||
"/sdapi/v1/server-kill", self.kill_studio, methods=["POST"]
|
||||
)
|
||||
self.add_api_route(
|
||||
"/sdapi/v1/server-restart",
|
||||
self.restart_studio,
|
||||
methods=["POST"],
|
||||
)
|
||||
self.add_api_route(
|
||||
"/sdapi/v1/server-stop", self.stop_studio, methods=["POST"]
|
||||
)
|
||||
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
def add_api_route(self, path:str, endpoint, **kwargs):
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if studio.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs
|
||||
return self.app.add_api_route(
|
||||
path, endpoint, dependencies=[Depends(self.auth)], **kwargs
|
||||
)
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
@@ -231,7 +280,13 @@ class ApiCompat:
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=server_name,
|
||||
port=port,
|
||||
timeout_keep_alive=studio.cmd_opts.timeout_keep_alive,
|
||||
root_path=root_path,
|
||||
)
|
||||
|
||||
def kill_studio(self):
|
||||
restart.stop_program()
|
||||
@@ -246,7 +301,7 @@ class ApiCompat:
|
||||
studio.state.begin(job="preprocess")
|
||||
preprocess(**args)
|
||||
studio.state.end()
|
||||
return models.PreprocessResponse(info='preprocess complete')
|
||||
return models.PreprocessResponse(info="preprocess complete")
|
||||
except:
|
||||
studio.state.end()
|
||||
|
||||
|
||||
1
apps/shark_studio/web/configs/foo.json
Normal file
1
apps/shark_studio/web/configs/foo.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
||||
@@ -3,12 +3,13 @@ import os
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
import apps.shark_studio.api.initializers as initialize
|
||||
|
||||
from ui.chat import chat_element
|
||||
from ui.sd import sd_element
|
||||
from ui.outputgallery import outputgallery_element
|
||||
|
||||
from modules import timer, initialize
|
||||
from apps.shark_studio.modules import timer
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
@@ -72,15 +73,13 @@ def launch_webui(address):
|
||||
|
||||
|
||||
def webui():
|
||||
from apps.shark_studio.shared_cmd_options import cmd_opts
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
launch_api = cmd_opts.api
|
||||
initialize.initialize()
|
||||
|
||||
from modules import shared, ui_tempdir, script_callbacks, ui, progress
|
||||
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
|
||||
@@ -131,16 +130,23 @@ def webui():
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.shark_studio.web.initializers import (
|
||||
config_gradio_tmp_imgs_folder,
|
||||
create_custom_models_folders,
|
||||
from apps.shark_studio.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
clear_tmp_mlir,
|
||||
clear_tmp_imgs,
|
||||
)
|
||||
from apps.shark_studio.api.utils import (
|
||||
create_checkpoint_folders,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
config_tmp()
|
||||
clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
create_checkpoint_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
@@ -149,10 +155,7 @@ def webui():
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.shark_studio.web.ui import load_ui_from_script
|
||||
|
||||
# init global sd pipeline and config
|
||||
studio.state._init()
|
||||
# from apps.shark_studio.web.ui import load_ui_from_script
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
@@ -209,9 +212,9 @@ def webui():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from apps.shark_studio.shared_cmd_options import cmd_opts
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
if cmd_opts.nowebui:
|
||||
if cmd_opts.webui == False:
|
||||
api_only()
|
||||
else:
|
||||
webui()
|
||||
|
||||
55
apps/shark_studio/web/ui/common_events.py
Normal file
55
apps/shark_studio/web/ui/common_events.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
HSLHue,
|
||||
hsl_color,
|
||||
)
|
||||
from apps.shark_studio.modules.embeddings import get_lora_metadata
|
||||
|
||||
|
||||
# Answers HTML to show the most frequent tags used when a LoRA was trained,
|
||||
# taken from the metadata of its .safetensors file.
|
||||
def lora_changed(lora_file):
|
||||
# tag frequency percentage, that gets maximum amount of the staring hue
|
||||
TAG_COLOR_THRESHOLD = 0.55
|
||||
# tag frequency percentage, above which a tag is displayed
|
||||
TAG_DISPLAY_THRESHOLD = 0.65
|
||||
# template for the html used to display a tag
|
||||
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
|
||||
|
||||
if lora_file == "None":
|
||||
return ["<div><i>No LoRA selected</i></div>"]
|
||||
elif not lora_file.lower().endswith(".safetensors"):
|
||||
return [
|
||||
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
|
||||
]
|
||||
else:
|
||||
metadata = get_lora_metadata(lora_file)
|
||||
if metadata:
|
||||
frequencies = metadata["frequencies"]
|
||||
return [
|
||||
"".join(
|
||||
[
|
||||
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
|
||||
]
|
||||
+ [
|
||||
TAG_HTML_TEMPLATE.format(
|
||||
color=hsl_color(
|
||||
(tag[1] - TAG_COLOR_THRESHOLD)
|
||||
/ (1 - TAG_COLOR_THRESHOLD),
|
||||
start=HSLHue.RED,
|
||||
end=HSLHue.GREEN,
|
||||
),
|
||||
tag=tag[0],
|
||||
)
|
||||
for tag in frequencies
|
||||
if tag[1] > TAG_DISPLAY_THRESHOLD
|
||||
],
|
||||
)
|
||||
]
|
||||
elif metadata is None:
|
||||
return [
|
||||
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
|
||||
]
|
||||
else:
|
||||
return [
|
||||
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
|
||||
]
|
||||
324
apps/shark_studio/web/ui/css/sd_dark_theme.css
Normal file
324
apps/shark_studio/web/ui/css/sd_dark_theme.css
Normal file
@@ -0,0 +1,324 @@
|
||||
/*
|
||||
Apply Gradio dark theme to the default Gradio theme.
|
||||
Procedure to upgrade the dark theme:
|
||||
- Using your browser, visit http://localhost:8080/?__theme=dark
|
||||
- Open your browser inspector, search for the .dark css class
|
||||
- Copy .dark class declarations, apply them here into :root
|
||||
*/
|
||||
|
||||
:root {
|
||||
--body-background-fill: var(--background-fill-primary);
|
||||
--body-text-color: var(--neutral-100);
|
||||
--color-accent-soft: var(--neutral-700);
|
||||
--background-fill-primary: var(--neutral-950);
|
||||
--background-fill-secondary: var(--neutral-900);
|
||||
--border-color-accent: var(--neutral-600);
|
||||
--border-color-primary: var(--neutral-700);
|
||||
--link-text-color-active: var(--secondary-500);
|
||||
--link-text-color: var(--secondary-500);
|
||||
--link-text-color-hover: var(--secondary-400);
|
||||
--link-text-color-visited: var(--secondary-600);
|
||||
--body-text-color-subdued: var(--neutral-400);
|
||||
--shadow-spread: 1px;
|
||||
--block-background-fill: var(--neutral-800);
|
||||
--block-border-color: var(--border-color-primary);
|
||||
--block_border_width: None;
|
||||
--block-info-text-color: var(--body-text-color-subdued);
|
||||
--block-label-background-fill: var(--background-fill-secondary);
|
||||
--block-label-border-color: var(--border-color-primary);
|
||||
--block_label_border_width: None;
|
||||
--block-label-text-color: var(--neutral-200);
|
||||
--block_shadow: None;
|
||||
--block_title_background_fill: None;
|
||||
--block_title_border_color: None;
|
||||
--block_title_border_width: None;
|
||||
--block-title-text-color: var(--neutral-200);
|
||||
--panel-background-fill: var(--background-fill-secondary);
|
||||
--panel-border-color: var(--border-color-primary);
|
||||
--panel_border_width: None;
|
||||
--checkbox-background-color: var(--neutral-800);
|
||||
--checkbox-background-color-focus: var(--checkbox-background-color);
|
||||
--checkbox-background-color-hover: var(--checkbox-background-color);
|
||||
--checkbox-background-color-selected: var(--secondary-600);
|
||||
--checkbox-border-color: var(--neutral-700);
|
||||
--checkbox-border-color-focus: var(--secondary-500);
|
||||
--checkbox-border-color-hover: var(--neutral-600);
|
||||
--checkbox-border-color-selected: var(--secondary-600);
|
||||
--checkbox-border-width: var(--input-border-width);
|
||||
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
|
||||
--checkbox-label-border-color: var(--border-color-primary);
|
||||
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
|
||||
--checkbox-label-border-width: var(--input-border-width);
|
||||
--checkbox-label-text-color: var(--body-text-color);
|
||||
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
|
||||
--error-background-fill: var(--background-fill-primary);
|
||||
--error-border-color: var(--border-color-primary);
|
||||
--error_border_width: None;
|
||||
--error-text-color: #ef4444;
|
||||
--input-background-fill: var(--neutral-800);
|
||||
--input-background-fill-focus: var(--secondary-600);
|
||||
--input-background-fill-hover: var(--input-background-fill);
|
||||
--input-border-color: var(--border-color-primary);
|
||||
--input-border-color-focus: var(--neutral-700);
|
||||
--input-border-color-hover: var(--input-border-color);
|
||||
--input_border_width: None;
|
||||
--input-placeholder-color: var(--neutral-500);
|
||||
--input_shadow: None;
|
||||
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
|
||||
--loader_color: None;
|
||||
--slider_color: None;
|
||||
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
|
||||
--table-border-color: var(--neutral-700);
|
||||
--table-even-background-fill: var(--neutral-950);
|
||||
--table-odd-background-fill: var(--neutral-900);
|
||||
--table-row-focus: var(--color-accent-soft);
|
||||
--button-border-width: var(--input-border-width);
|
||||
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
|
||||
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
|
||||
--button-cancel-border-color: #dc2626;
|
||||
--button-cancel-border-color-hover: var(--button-cancel-border-color);
|
||||
--button-cancel-text-color: white;
|
||||
--button-cancel-text-color-hover: var(--button-cancel-text-color);
|
||||
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
|
||||
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
|
||||
--button-primary-border-color: var(--primary-500);
|
||||
--button-primary-border-color-hover: var(--button-primary-border-color);
|
||||
--button-primary-text-color: white;
|
||||
--button-primary-text-color-hover: var(--button-primary-text-color);
|
||||
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
|
||||
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
|
||||
--button-secondary-border-color: var(--neutral-600);
|
||||
--button-secondary-border-color-hover: var(--button-secondary-border-color);
|
||||
--button-secondary-text-color: white;
|
||||
--button-secondary-text-color-hover: var(--button-secondary-text-color);
|
||||
--block-border-width: 1px;
|
||||
--block-label-border-width: 1px;
|
||||
--form-gap-width: 1px;
|
||||
--error-border-width: 1px;
|
||||
--input-border-width: 1px;
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--background-fill-primary);
|
||||
}
|
||||
|
||||
.generating.svelte-zlszon.svelte-zlszon {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.generating {
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
#chatbot {
|
||||
height: 100% !important;
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
.gradio-container {
|
||||
max-width: var(--size-full) !important;
|
||||
}
|
||||
}
|
||||
|
||||
.gradio-container .contain {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
color: transparent;
|
||||
background-color: transparent;
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
#prompt_box_outer div:first-child {
|
||||
border-radius: 0 !important
|
||||
}
|
||||
|
||||
#prompt_box textarea, #negative_prompt_box textarea {
|
||||
background-color: var(--background-fill-primary) !important;
|
||||
}
|
||||
|
||||
#prompt_examples {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
#prompt_examples svg {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
#ui_body {
|
||||
padding: var(--size-2) !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
|
||||
#img_result+div {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
footer {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
#gallery + div {
|
||||
border-radius: 0 !important;
|
||||
}
|
||||
|
||||
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
|
||||
#gallery .thumbnail-item.thumbnail-lg {
|
||||
aspect-ratio: unset;
|
||||
max-height: calc(55vh - (2 * var(--spacing-lg)));
|
||||
}
|
||||
@media (min-width: 1921px) {
|
||||
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
|
||||
#gallery .grid-wrap, #gallery .preview{
|
||||
min-height: calc(768px + 4px + var(--size-14));
|
||||
max-height: calc(768px + 4px + var(--size-14));
|
||||
}
|
||||
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
|
||||
#gallery .thumbnail-item.thumbnail-lg {
|
||||
max-height: 770px !important;
|
||||
}
|
||||
}
|
||||
/* Don't upscale when viewing in solo image mode */
|
||||
#gallery .preview img {
|
||||
object-fit: scale-down;
|
||||
}
|
||||
/* Navbar images in cover mode*/
|
||||
#gallery .preview .thumbnail-item img {
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
/* Limit the stable diffusion text output height */
|
||||
#std_output textarea {
|
||||
max-height: 215px;
|
||||
}
|
||||
|
||||
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
|
||||
#gallery .wrap.default {
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* Import Png info box */
|
||||
#txt2img_prompt_image {
|
||||
height: var(--size-32) !important;
|
||||
}
|
||||
|
||||
/* Hide "remove buttons" from ui dropdowns */
|
||||
#custom_model .token-remove.remove-all,
|
||||
#lora_weights .token-remove.remove-all,
|
||||
#scheduler .token-remove.remove-all,
|
||||
#device .token-remove.remove-all,
|
||||
#stencil_model .token-remove.remove-all {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Hide selected items from ui dropdowns */
|
||||
#custom_model .options .item .inner-item,
|
||||
#scheduler .options .item .inner-item,
|
||||
#device .options .item .inner-item,
|
||||
#stencil_model .options .item .inner-item {
|
||||
display:none;
|
||||
}
|
||||
|
||||
/* workarounds for container=false not currently working for dropdowns */
|
||||
.dropdown_no_container {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
#output_subdir_container :first-child {
|
||||
border: none;
|
||||
}
|
||||
|
||||
/* reduced animation load when generating */
|
||||
.generating {
|
||||
animation-play-state: paused !important;
|
||||
}
|
||||
|
||||
/* better clarity when progress bars are minimal */
|
||||
.meta-text {
|
||||
background-color: var(--block-label-background-fill);
|
||||
}
|
||||
|
||||
/* lora tag pills */
|
||||
.lora-tags {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
color: var(--block-info-text-color) !important;
|
||||
padding: var(--block-padding);
|
||||
}
|
||||
|
||||
.lora-tag {
|
||||
display: inline-block;
|
||||
height: 2em;
|
||||
color: rgb(212 212 212) !important;
|
||||
margin-right: 5pt;
|
||||
margin-bottom: 5pt;
|
||||
padding: 2pt 5pt;
|
||||
border-radius: 5pt;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.lora-model {
|
||||
margin-bottom: var(--spacing-lg);
|
||||
color: var(--block-info-text-color) !important;
|
||||
line-height: var(--line-sm);
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe table.table {
|
||||
/* works around a gradio bug that always shows scrollbars */
|
||||
overflow: clip auto;
|
||||
}
|
||||
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs);
|
||||
}
|
||||
|
||||
.output_icon_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
min-width: 7em !important;
|
||||
}
|
||||
|
||||
/* output gallery should take up most of the viewport height regardless of image size/number */
|
||||
#outputgallery_gallery .fixed-height {
|
||||
min-height: 89vh !important;
|
||||
}
|
||||
|
||||
/* don't stretch non-square images to be square, breaking their aspect ratio */
|
||||
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#top_logo.logo_centered img{
|
||||
object-fit: scale-down;
|
||||
position: absolute;
|
||||
width: 80%;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
BIN
apps/shark_studio/web/ui/logos/nod-icon.png
Normal file
BIN
apps/shark_studio/web/ui/logos/nod-icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
apps/shark_studio/web/ui/logos/nod-logo.png
Normal file
BIN
apps/shark_studio/web/ui/logos/nod-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
416
apps/shark_studio/web/ui/outputgallery.py
Normal file
416
apps/shark_studio/web/ui/outputgallery.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import glob
|
||||
import gradio as gr
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.api.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.shark_studio.web.ui.utils import nodlogo_loc
|
||||
from apps.shark_studio.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
output_dir = get_generated_imgs_path()
|
||||
|
||||
|
||||
def outputgallery_filenames(subdir) -> list[str]:
|
||||
new_dir_path = os.path.join(output_dir, subdir)
|
||||
if os.path.exists(new_dir_path):
|
||||
filenames = [
|
||||
glob.glob(new_dir_path + "/" + ext)
|
||||
for ext in ("*.png", "*.jpg", "*.jpeg")
|
||||
]
|
||||
|
||||
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def output_subdirs() -> list[str]:
|
||||
# Gets a list of subdirectories of output_dir and below, as relative paths.
|
||||
relative_paths = [
|
||||
os.path.relpath(entry[0], output_dir)
|
||||
for entry in os.walk(
|
||||
output_dir, followlinks=cmd_opts.output_gallery_followlinks
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always including the subdir that will take any
|
||||
# images generated today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that the date named ones we probably
|
||||
# created in this or previous sessions come first, sorted with the most
|
||||
# recent first. Other subdirs are listed after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
result_paths = generated_paths + sorted(
|
||||
[
|
||||
path
|
||||
for path in relative_paths
|
||||
if (not path.isnumeric()) and path != "."
|
||||
]
|
||||
)
|
||||
|
||||
return result_paths
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_element:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue:
|
||||
# https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
subdirectory_paths = gr.State(value=[])
|
||||
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
show_download_button=False,
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
columns=4,
|
||||
)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
with gr.Column(
|
||||
scale=15,
|
||||
min_width=160,
|
||||
elem_id="output_subdir_container",
|
||||
):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
type="value",
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
min_width=32,
|
||||
elem_classes="output_icon_button",
|
||||
):
|
||||
open_subdir = gr.Button(
|
||||
variant="secondary",
|
||||
value="\U0001F5C1", # unicode open folder
|
||||
interactive=False,
|
||||
size="sm",
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
min_width=32,
|
||||
elem_classes="output_icon_button",
|
||||
):
|
||||
refresh = gr.Button(
|
||||
variant="secondary",
|
||||
value="\u21BB", # unicode clockwise arrow circle
|
||||
size="sm",
|
||||
)
|
||||
|
||||
image_columns = gr.Slider(
|
||||
label="Columns shown", value=4, minimum=1, maximum=16, step=1
|
||||
)
|
||||
outputgallery_filename = gr.Textbox(
|
||||
label="Filename",
|
||||
value="None",
|
||||
interactive=False,
|
||||
show_copy_button=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
label="Parameter Information", open=False
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
with gr.Row():
|
||||
outputgallery_sendto_sd = gr.Button(
|
||||
value="Stable Diffusion",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_image_columns_change(columns):
|
||||
return gr.Gallery(columns=columns)
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_open_subdir(subdir):
|
||||
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
|
||||
|
||||
if os.path.isdir(subdir_path):
|
||||
if sys.platform == "linux":
|
||||
subprocess.run(["xdg-open", subdir_path])
|
||||
elif sys.platform == "darwin":
|
||||
subprocess.run(["open", subdir_path])
|
||||
elif sys.platform == "win32":
|
||||
os.startfile(subdir_path)
|
||||
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up-to-date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most
|
||||
# recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, new_subdir)}"
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab
|
||||
# has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as
|
||||
# new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, subdir)} - {status}"
|
||||
)
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing,
|
||||
# (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for
|
||||
# the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
if params:
|
||||
if params["source"] == "missing":
|
||||
return [
|
||||
"Could not find this image file, refresh the gallery and update the images",
|
||||
[["Status", "File missing"]],
|
||||
]
|
||||
else:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh
|
||||
# to populate the subdirectory select box and the images from the most
|
||||
# recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls'
|
||||
# definitions as when you refresh the browser you always get what was
|
||||
# *initially* set, which won't include any new subdirectories or images
|
||||
# that might have created since the application was started. Doing it
|
||||
# this way means a browser refresh/reload always gets the most
|
||||
# up-to-date data.
|
||||
def on_select_tab(subdir_paths, request: gr.Request):
|
||||
local_client = request.headers["host"].startswith(
|
||||
"127.0.0.1:"
|
||||
) or request.headers["host"].startswith("localhost:")
|
||||
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("") + [gr.update(interactive=local_client)]
|
||||
else:
|
||||
return (
|
||||
# Change nothing, (only untyped gr.update() does this)
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the
|
||||
# gallery avoids current images being shown replacing piecemeal and
|
||||
# prevents weirdness and errors if the user selects an image during the
|
||||
# replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
outputs=[gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(
|
||||
fn=on_image_columns_change,
|
||||
inputs=[image_columns],
|
||||
outputs=[gallery],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_sd,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been given the .select function for our tab, so set it up
|
||||
def outputgallery_tab_select(select):
|
||||
select(
|
||||
fn=on_select_tab,
|
||||
inputs=[subdirectory_paths],
|
||||
outputs=[
|
||||
subdirectories,
|
||||
subdirectory_paths,
|
||||
gallery_files,
|
||||
gallery,
|
||||
logo,
|
||||
open_subdir,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
@@ -24,130 +24,31 @@ from apps.shark_studio.api.utils import (
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
sd_model_map,
|
||||
StableDiffusion,
|
||||
)
|
||||
from apps.shark_studio.api.schedulers import (
|
||||
scheduler_model_map,
|
||||
shark_sd_fn,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
preprocessor_model_map,
|
||||
control_adapter_model_map,
|
||||
PreprocessorModel,
|
||||
cnet_preview,
|
||||
)
|
||||
from apps.shark_studio.modules.schedulers import (
|
||||
scheduler_model_map,
|
||||
)
|
||||
from apps.shark_studio.modules.img_processing import (
|
||||
resampler_list,
|
||||
resize_stencil,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
get_generation_text_info,
|
||||
nodlogo_loc,
|
||||
)
|
||||
from apps.shark_studio.web.utils.state import (
|
||||
get_generation_text_info,
|
||||
status_label,
|
||||
)
|
||||
from apps.shark_studio.web.ui.common_events import lora_changed
|
||||
|
||||
sd_pipe = None
|
||||
|
||||
|
||||
# NOTE: Each `hf_model_id` should have its own starting configuration.
|
||||
|
||||
# model_vmfb_key = ""
|
||||
|
||||
def shark_sd_fn(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
base_model_id: str,
|
||||
custom_checkpoints: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
lora_weights: str | list,
|
||||
lora_hf_ids: str | list,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
preprocessed_hints: list,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
|
||||
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
continue
|
||||
elif stencil is None and any(img is not None for img in [images[i], preprocessed_hints[i]]):
|
||||
images[i] = None
|
||||
preprocessed_hints[i] = None
|
||||
elif images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
else:
|
||||
image = None
|
||||
if image:
|
||||
image, _, _, = resize_stencil(image, width, height)
|
||||
|
||||
device_id = None
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
submit_pipe_kwargs = {
|
||||
base_model_id: base_model_id,
|
||||
height: height,
|
||||
width: width,
|
||||
precision: precision,
|
||||
device: device,
|
||||
extra_model_ids: extra_model_ids,
|
||||
embeddings: lora_hf_ids,
|
||||
import_ir: cmd_opts.import_ir,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
|
||||
|
||||
|
||||
global sd_pipe
|
||||
global sd_pipe_kwargs
|
||||
|
||||
for key in
|
||||
|
||||
if sd_pipe is None:
|
||||
history[-1][-1] = "Getting the pipeline ready..."
|
||||
yield history, ""
|
||||
|
||||
# Initializes the pipeline and retrieves IR based on all
|
||||
# parameters that are static in the turbine output format,
|
||||
# which is currently MLIR in the torch dialect.
|
||||
|
||||
sd_pipe = SharkStableDiffusionPipeline(
|
||||
**submit_pipe_kwargs
|
||||
)
|
||||
sd_pipe.queue_compile()
|
||||
|
||||
for prompt, msg, exec_time in progress.tqdm(
|
||||
sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
),
|
||||
desc="Generating Image...",
|
||||
):
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def view_json_file(file_obj):
|
||||
content = ""
|
||||
@@ -155,17 +56,33 @@ def view_json_file(file_obj):
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
sd_fn_sig = signature(shark_sd_fn)
|
||||
max_controlnets = 5
|
||||
|
||||
max_controlnets = 3
|
||||
max_loras = 5
|
||||
|
||||
|
||||
def show_loras(k):
|
||||
k = int(k)
|
||||
return [gr.Dropdown(visible=True)]*k + [gr.Dropdown(visible=False, value="None")]*(max_textboxes-k)
|
||||
return gr.State(
|
||||
[gr.Dropdown(visible=True)] * k
|
||||
+ [gr.Dropdown(visible=False, value="None")] * (max_loras - k)
|
||||
)
|
||||
|
||||
|
||||
def show_controlnets(k):
|
||||
k = int(k)
|
||||
return [gr.Row(visible=True)]*k + [gr.Row(visible=False)]*(max_textboxes-k)
|
||||
return [
|
||||
gr.State(
|
||||
[
|
||||
[gr.Row(visible=True, render=True)] * k
|
||||
+ [gr.Row(visible=False)] * (max_controlnets - k)
|
||||
]
|
||||
),
|
||||
gr.State([None] * k),
|
||||
gr.State([None] * k),
|
||||
gr.State([None] * k),
|
||||
]
|
||||
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = Image.fromarray(
|
||||
@@ -182,10 +99,9 @@ def create_canvas(width, height):
|
||||
}
|
||||
return EditorValue(img_dict)
|
||||
|
||||
|
||||
def import_original(original_img, width, height):
|
||||
resized_img, _, _ = resize_stencil(
|
||||
original_img, width, height
|
||||
)
|
||||
resized_img, _, _ = resize_stencil(original_img, width, height)
|
||||
img_dict = {
|
||||
"background": resized_img,
|
||||
"layers": [resized_img],
|
||||
@@ -196,6 +112,7 @@ def import_original(original_img, width, height):
|
||||
crop_size=(width, height),
|
||||
)
|
||||
|
||||
|
||||
def update_cn_input(
|
||||
model,
|
||||
width,
|
||||
@@ -203,7 +120,6 @@ def update_cn_input(
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
index,
|
||||
):
|
||||
if model == None:
|
||||
stencils[index] = None
|
||||
@@ -271,80 +187,99 @@ def update_cn_input(
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
|
||||
|
||||
sd_fn_inputs = []
|
||||
sd_fn_sig = signature(shark_sd_fn).replace()
|
||||
for i in sd_fn_sig.parameters:
|
||||
sd_fn_inputs.append(i)
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
# Get a list of arguments needed for the API call, then
|
||||
# initialize an empty list that will manage the corresponding
|
||||
# gradio values.
|
||||
inputs_list = gr.State(signature(shark_sd_fn))
|
||||
inputs_args = gr.State([None] * len(inputs_list))
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
save_sd_config = gr.Button(label="Save Config", scale=1)
|
||||
load_sd_config = gr.FileExplorer("Load Config", scale=1)
|
||||
clear_sd_config = gr.ClearButton("Clear Config", scale=1)
|
||||
with gr.Column(elem_if="ui_body"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row(variant="compact", equal_height=True):
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
elem_id="demo_title_outer",
|
||||
):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Column(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group()
|
||||
sd_model_info = (
|
||||
f"Checkpoint Path: {str(get_checkpoint_path())}"
|
||||
)
|
||||
sd_base = gr.Dropdown(
|
||||
label="Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2.1-base",
|
||||
choices=get_base_models(),
|
||||
) # base_model_id
|
||||
sd_checkpoint = gr.Dropdown(
|
||||
label="Checkpoints (optional)",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
choices=get_checkpoints(sd_base),
|
||||
) #
|
||||
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
|
||||
"\\", "\n\\"
|
||||
)
|
||||
sd_vae_info = f"VAE Path: {sd_vae_info}"
|
||||
sd_custom_vae = gr.Dropdown(
|
||||
label=f"Custom VAE Models",
|
||||
info=sd_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(cmd_opts.custom_vae)
|
||||
if cmd_opts.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_checkpoints("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
sd_model_info = (
|
||||
f"Checkpoint Path: {str(get_checkpoints_path())}"
|
||||
)
|
||||
sd_base = gr.Dropdown(
|
||||
label="Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
choices=sd_model_map.keys(),
|
||||
) # base_model_id
|
||||
sd_custom_weights = gr.Dropdown(
|
||||
label="Weights (Optional)",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
allow_custom_value=True,
|
||||
choices=get_checkpoints(sd_base),
|
||||
) #
|
||||
with gr.Column(scale=2):
|
||||
sd_vae_info = (
|
||||
str(get_checkpoints_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
sd_vae_info = f"VAE Path: {sd_vae_info}"
|
||||
sd_custom_vae = gr.Dropdown(
|
||||
label=f"Custom VAE Models",
|
||||
info=sd_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(cmd_opts.custom_vae)
|
||||
if cmd_opts.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_checkpoints("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config", size="sm"
|
||||
)
|
||||
load_sd_config = gr.FileExplorer(
|
||||
label="Load Config",
|
||||
root=os.path.basename("./configs"),
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
value=cmd_opts.prompts[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
value=cmd_opts.negative_prompts[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
with gr.Accordion(label = "Input Image", open=False):
|
||||
|
||||
with gr.Accordion(label="Input Image", open=False):
|
||||
# TODO: make this import image prompt info if it exists
|
||||
sd_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
@@ -352,41 +287,94 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
height=300,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Accordion(label="Embeddings options", open=False):
|
||||
with gr.Accordion(
|
||||
label="Embeddings options", open=False, render=True
|
||||
):
|
||||
sd_lora_info = (
|
||||
str(get_checkpoints_path("loras"))
|
||||
).replace("\\", "\n\\")
|
||||
num_loras = gr.Slider(1, max_loras, value=1, step=1, label="LoRA Count")
|
||||
loras = []
|
||||
num_loras = gr.Slider(
|
||||
1, max_loras, value=1, step=1, label="LoRA Count"
|
||||
)
|
||||
loras = gr.State([])
|
||||
for i in range(max_loras):
|
||||
lora_opt = gr.Dropdown(
|
||||
allow_custom_value=False,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=sd_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
with gr.Row():
|
||||
lora_opt = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=sd_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_checkpoints("lora"),
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
gr.on(
|
||||
triggers=[lora_opt.change],
|
||||
fn=lora_changed,
|
||||
inputs=[lora_opt],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
loras.value.append(lora_opt)
|
||||
|
||||
num_loras.change(show_loras, [num_loras], [loras])
|
||||
with gr.Accordion(label="Advanced Options", open=True):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list,
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 768, value=cmd_opts.height, step=8, label="Height"
|
||||
384,
|
||||
768,
|
||||
value=cmd_opts.height,
|
||||
step=8,
|
||||
label="Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 768, value=cmd_opts.width, step=8, label="Width"
|
||||
384,
|
||||
768,
|
||||
value=cmd_opts.width,
|
||||
step=8,
|
||||
label="Width",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
1,
|
||||
100,
|
||||
value=cmd_opts.steps,
|
||||
step=1,
|
||||
label="Steps",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=cmd_opts.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=cmd_opts.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
cmd_opts.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
strength = gr.Slider(
|
||||
@@ -402,6 +390,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=cmd_opts.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
label="Low VRAM",
|
||||
@@ -416,38 +411,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=cmd_opts.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=cmd_opts.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
cmd_opts.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=cmd_opts.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=cmd_opts.seed,
|
||||
@@ -457,40 +420,53 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=get_available_devices[0],
|
||||
choices=get_available_devices,
|
||||
value=get_available_devices()[0],
|
||||
choices=get_available_devices(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Accordion(label="Controlnet Options", open=False):
|
||||
with gr.Accordion(
|
||||
label="Controlnet Options", open=False, render=False
|
||||
):
|
||||
sd_cnet_info = (
|
||||
str(get_checkpoints_path("controlnet"))
|
||||
).replace("\\", "\n\\")
|
||||
num_cnets = gr.Slider(1, max_controlnets, value=1, step=1, label="Controlnet Count")
|
||||
num_cnets = gr.Slider(
|
||||
0,
|
||||
max_controlnets,
|
||||
value=0,
|
||||
step=1,
|
||||
label="Controlnet Count",
|
||||
)
|
||||
cnet_rows = []
|
||||
stencils = []
|
||||
images = []
|
||||
preprocessed_hints = []
|
||||
stencils = gr.State([])
|
||||
images = gr.State([])
|
||||
preprocessed_hints = gr.State([])
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
label="Control Mode",
|
||||
)
|
||||
|
||||
for i in range(max_controlnets):
|
||||
with gr.Row as cnet_row:
|
||||
with gr.Row(visible=False) as cnet_row:
|
||||
with gr.Column():
|
||||
cnet_gen = gr.Button(
|
||||
value="Preprocess controlnet input",
|
||||
)
|
||||
cnet_processor = gr.Dropdown(
|
||||
cnet_model = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Controlnet Preprocessor",
|
||||
label=f"Controlnet Model",
|
||||
info=sd_cnet_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + controlnet_list + get_custom_model_files("controlnet"),
|
||||
)
|
||||
cnet_adapter = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Controlnet Adapter",
|
||||
info=sd_cnet_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + controlnet_list + get_custom_model_files("controlnet"),
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
+ get_checkpoints("controlnet"),
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
@@ -529,14 +505,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
visible=True,
|
||||
label="Preprocessed Hint",
|
||||
interactive=True,
|
||||
show_label=True
|
||||
show_label=True,
|
||||
)
|
||||
use_input_img.click(
|
||||
import_original,
|
||||
[sd_init_image, canvas_width, canvas_height],
|
||||
[cnet_image],
|
||||
[cnet_input],
|
||||
)
|
||||
|
||||
cnet_model.change(
|
||||
fn=update_cn_input,
|
||||
inputs=[
|
||||
@@ -563,7 +538,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_image,
|
||||
cnet_input,
|
||||
],
|
||||
)
|
||||
gr.on(
|
||||
@@ -583,12 +558,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
cnet_rows.append(cnet_row)
|
||||
cnet_rows.value.append(cnet_row)
|
||||
|
||||
num_cnets.change(show_controlnets, num_cnets, cnet_rows)
|
||||
num_cnets.change(
|
||||
show_controlnets,
|
||||
[num_cnets],
|
||||
[cnet_rows, stencils, images, preprocessed_hints],
|
||||
)
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
img2img_gallery = gr.Gallery(
|
||||
sd_gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
@@ -596,14 +575,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{i2i_model_info}\n"
|
||||
value=f"{sd_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
@@ -631,12 +610,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
batch_size,
|
||||
scheduler,
|
||||
sd_base,
|
||||
sd_checkpoint,
|
||||
sd_custom_weights,
|
||||
sd_custom_vae,
|
||||
precision,
|
||||
device,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
loras,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -652,13 +630,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
show_progress="minimal" if cmd_opts.progress_bar else "none",
|
||||
show_progress="minimal",
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
|
||||
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=img2img_status,
|
||||
outputs=sd_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
@@ -670,10 +648,3 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,33 @@
|
||||
def nodlogo_loc():
|
||||
return "foo"
|
||||
from enum import IntEnum
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def get_checkpoints_path(model_type: str = None):
|
||||
return "foo"
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_checkpoints():
|
||||
return "foo"
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
|
||||
|
||||
class HSLHue(IntEnum):
|
||||
RED = 0
|
||||
YELLOW = 60
|
||||
GREEN = 120
|
||||
CYAN = 180
|
||||
BLUE = 240
|
||||
MAGENTA = 300
|
||||
|
||||
|
||||
def hsl_color(alpha: float, start, end):
|
||||
b = (end - start) * (alpha if alpha > 0 else 0)
|
||||
result = b + start
|
||||
|
||||
# Return a CSS HSL string
|
||||
return f"hsl({math.floor(result)}, 80%, 35%)"
|
||||
|
||||
74
apps/shark_studio/web/utils/globals.py
Normal file
74
apps/shark_studio/web/utils/globals.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import gc
|
||||
|
||||
"""
|
||||
The global objects include SD pipeline and config.
|
||||
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
|
||||
Also we could avoid memory leak when switching models by clearing the cache.
|
||||
"""
|
||||
|
||||
|
||||
def _init():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
|
||||
def set_sd_obj(value):
|
||||
global _sd_obj
|
||||
_sd_obj = value
|
||||
|
||||
|
||||
def set_sd_scheduler(key):
|
||||
global _sd_obj
|
||||
_sd_obj.scheduler = _schedulers[key]
|
||||
|
||||
|
||||
def set_sd_status(value):
|
||||
global _sd_obj
|
||||
_sd_obj.status = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global _config_obj
|
||||
_config_obj = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
global _schedulers
|
||||
_schedulers = value
|
||||
|
||||
|
||||
def get_sd_obj():
|
||||
global _sd_obj
|
||||
return _sd_obj
|
||||
|
||||
|
||||
def get_sd_status():
|
||||
global _sd_obj
|
||||
return _sd_obj.status
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
global _config_obj
|
||||
return _config_obj
|
||||
|
||||
|
||||
def get_scheduler(key):
|
||||
global _schedulers
|
||||
return _schedulers[key]
|
||||
|
||||
|
||||
def clear_cache():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
del _sd_obj
|
||||
del _config_obj
|
||||
del _schedulers
|
||||
gc.collect()
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
6
apps/shark_studio/web/utils/metadata/__init__.py
Normal file
6
apps/shark_studio/web/utils/metadata/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .png_metadata import (
|
||||
import_png_metadata,
|
||||
)
|
||||
from .display import (
|
||||
displayable_metadata,
|
||||
)
|
||||
45
apps/shark_studio/web/utils/metadata/csv_metadata.py
Normal file
45
apps/shark_studio/web/utils/metadata/csv_metadata.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import csv
|
||||
import os
|
||||
from .format import humanize, humanizable
|
||||
|
||||
|
||||
def csv_path(image_filename: str):
|
||||
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
|
||||
|
||||
|
||||
def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def matching_filename(image_filename: str, row):
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
|
||||
# the value of the OUTPUT key
|
||||
return os.path.basename(image_filename) in (
|
||||
row[-1] if isinstance(row, list) else row["OUTPUT"]
|
||||
)
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
with open(csv_filename, "r", newline="") as csv_file:
|
||||
# We use a reader or DictReader here for images_details.csv depending on whether we think it
|
||||
# has headers or not. Having headers means less guessing of the format.
|
||||
has_header = csv.Sniffer().has_header(csv_file.read(2048))
|
||||
csv_file.seek(0)
|
||||
|
||||
reader = (
|
||||
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
|
||||
)
|
||||
|
||||
matches = [
|
||||
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
|
||||
humanize(row)
|
||||
for row in reader
|
||||
if row
|
||||
and (has_header or humanizable(row))
|
||||
and matching_filename(image_filename, row)
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
53
apps/shark_studio/web/utils/metadata/display.py
Normal file
53
apps/shark_studio/web/utils/metadata/display.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from .png_metadata import parse_generation_parameters
|
||||
from .exif_metadata import has_exif, parse_exif
|
||||
from .csv_metadata import has_csv, parse_csv
|
||||
from .format import compact, humanize
|
||||
|
||||
|
||||
def displayable_metadata(image_filename: str) -> dict:
|
||||
if not os.path.isfile(image_filename):
|
||||
return {"source": "missing", "parameters": {}}
|
||||
|
||||
pil_image = Image.open(image_filename)
|
||||
|
||||
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
|
||||
# and we go via that for SendTo, and is directly tied to the image)
|
||||
if "parameters" in pil_image.info:
|
||||
return {
|
||||
"source": "png",
|
||||
"parameters": compact(
|
||||
parse_generation_parameters(pil_image.info["parameters"])
|
||||
),
|
||||
}
|
||||
|
||||
# we have a matching json file (next most likely to be accurate when it's there)
|
||||
json_path = os.path.splitext(image_filename)[0] + ".json"
|
||||
if os.path.isfile(json_path):
|
||||
with open(json_path) as params_file:
|
||||
return {
|
||||
"source": "json",
|
||||
"parameters": compact(
|
||||
humanize(json.load(params_file), includes_filename=False)
|
||||
),
|
||||
}
|
||||
|
||||
# we have a CSV file so try that (can be different shapes, and it usually has no
|
||||
# headers/param names so of the things we we *know* have parameters, it's the
|
||||
# last resort)
|
||||
if has_csv(image_filename):
|
||||
params = parse_csv(image_filename)
|
||||
if params: # we might not have found the filename in the csv
|
||||
return {
|
||||
"source": "csv",
|
||||
"parameters": compact(params), # already humanized
|
||||
}
|
||||
|
||||
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
|
||||
if has_exif(image_filename):
|
||||
return {"source": "exif", "parameters": parse_exif(pil_image)}
|
||||
|
||||
# we've got nothing
|
||||
return None
|
||||
52
apps/shark_studio/web/utils/metadata/exif_metadata.py
Normal file
52
apps/shark_studio/web/utils/metadata/exif_metadata.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
|
||||
|
||||
|
||||
def has_exif(image_filename: str) -> bool:
|
||||
return True if Image.open(image_filename).getexif() else False
|
||||
|
||||
|
||||
def parse_exif(pil_image: Image) -> dict:
|
||||
img_exif = pil_image.getexif()
|
||||
|
||||
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
|
||||
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
|
||||
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
|
||||
# dependency
|
||||
exif_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for (key, val) in img_exif.items()
|
||||
if key in TAGS
|
||||
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
def try_get_ifd(ifd_id):
|
||||
try:
|
||||
return img_exif.get_ifd(ifd_id).items()
|
||||
except KeyError:
|
||||
return {}
|
||||
|
||||
ifd_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for ifd_id in IFD
|
||||
for (key, val) in try_get_ifd(ifd_id)
|
||||
if ifd_id != IFD.GPSInfo
|
||||
and key in TAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
gps_tags = {
|
||||
GPSTAGS.get(key, key): str(val)
|
||||
for (key, val) in try_get_ifd(IFD.GPSInfo)
|
||||
if key in GPSTAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
return {**exif_tags, **ifd_tags, **gps_tags}
|
||||
143
apps/shark_studio/web/utils/metadata/format.py
Normal file
143
apps/shark_studio/web/utils/metadata/format.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# As SHARK has evolved more columns have been added to images_details.csv. However, since
|
||||
# no version of the CSV has any headers (yet) we don't actually have anything within the
|
||||
# file that tells us which parameter each column is for. So this is a list of known patterns
|
||||
# indexed by length which is what we're going to have to use to guess which columns are the
|
||||
# right ones for the file we're looking at.
|
||||
|
||||
# The same ordering is used for JSON, but these do have key names, however they are not very
|
||||
# human friendly, nor do they match up with the what is written to the .png headers
|
||||
|
||||
# So these are functions to try and get something consistent out the raw input from all
|
||||
# these sources
|
||||
|
||||
PARAMS_FORMATS = {
|
||||
9: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
10: {
|
||||
"MODEL": "Model",
|
||||
"VARIANT": "Variant",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
12: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_CURRENT = {
|
||||
"VARIANT": "Model",
|
||||
"VAE": "VAE",
|
||||
"LORA": "LoRA",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
}
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
# we don't want to alter the original dictionary
|
||||
result = dict(metadata)
|
||||
|
||||
# discard the filename because we should already have it
|
||||
if result.keys() & {"Filename"}:
|
||||
result.pop("Filename")
|
||||
|
||||
# make showing the sizes more compact by using only one line each
|
||||
if result.keys() & {"Size-1", "Size-2"}:
|
||||
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
|
||||
elif result.keys() & {"Height", "Width"}:
|
||||
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
|
||||
|
||||
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
|
||||
hires_y = result.pop("Hires resize-1")
|
||||
hires_x = result.pop("Hires resize-2")
|
||||
|
||||
if hires_x == 0 and hires_y == 0:
|
||||
result["Hires resize"] = "None"
|
||||
else:
|
||||
result["Hires resize"] = f"{hires_y}x{hires_x}"
|
||||
|
||||
# remove VAE if it exists and is empty
|
||||
if (result.keys() & {"VAE"}) and (
|
||||
not result["VAE"] or result["VAE"] == "None"
|
||||
):
|
||||
result.pop("VAE")
|
||||
|
||||
# remove LoRA if it exists and is empty
|
||||
if (result.keys() & {"LoRA"}) and (
|
||||
not result["LoRA"] or result["LoRA"] == "None"
|
||||
):
|
||||
result.pop("LoRA")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
return lookup_key in PARAMS_FORMATS.keys()
|
||||
|
||||
|
||||
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
|
||||
# For lists we can only work based on the length, we have no other information
|
||||
if isinstance(metadata, list):
|
||||
if humanizable(metadata, includes_filename):
|
||||
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we just use the current format which is assumed to
|
||||
# have everything currently known about. Then we swap keys in the metadata
|
||||
# that match keys in the format for the friendlier name that we have set
|
||||
# in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_CURRENT
|
||||
|
||||
return {
|
||||
format[key]: metadata[key]
|
||||
for key in format.keys()
|
||||
if key in metadata.keys() and metadata[key]
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
222
apps/shark_studio/web/utils/metadata/png_metadata.py
Normal file
222
apps/shark_studio/web/utils/metadata/png_metadata.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from apps.shark_studio.api.utils import (
|
||||
get_checkpoint_pathfile,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
sd_model_map,
|
||||
)
|
||||
from apps.shark_studio.modules.schedulers import (
|
||||
scheduler_model_map,
|
||||
)
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
res = {}
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
lines.append(lastline)
|
||||
lastline = ""
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if line.startswith("Negative prompt:"):
|
||||
done_with_prompt = True
|
||||
line = line[16:].strip()
|
||||
|
||||
if done_with_prompt:
|
||||
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
||||
else:
|
||||
prompt += ("" if prompt == "" else "\n") + line
|
||||
|
||||
res["Prompt"] = prompt
|
||||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k + "-1"] = m.group(1)
|
||||
res[k + "-2"] = m.group(2)
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
# Missing CLIP skip means it was set to 1 (the default)
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res[
|
||||
"Prompt"
|
||||
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
res["Hires resize-2"] = 0
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def try_find_model_base_from_png_metadata(
|
||||
file: str, folder: str = "models"
|
||||
) -> str:
|
||||
custom = ""
|
||||
|
||||
# Remove extension from file info
|
||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||
file = Path(file).stem
|
||||
# Check for the file name match with one of the local ckpt or safetensors files
|
||||
if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file():
|
||||
custom = file + ".ckpt"
|
||||
if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file():
|
||||
custom = file + ".safetensors"
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def find_model_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
png_hf_id = ""
|
||||
png_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
model_file = metadata[key]
|
||||
png_custom = try_find_model_base_from_png_metadata(model_file)
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if model_file in sd_model_map:
|
||||
png_custom = model_file
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom and model_file.count("/"):
|
||||
png_hf_id = model_file
|
||||
# No matching model was found
|
||||
if not png_custom and not png_hf_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% model_file
|
||||
)
|
||||
|
||||
return png_custom, png_hf_id
|
||||
|
||||
|
||||
def find_vae_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> str:
|
||||
vae_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
vae_file = metadata[key]
|
||||
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
|
||||
|
||||
# VAE input is optional, should not print or throw an error if missing
|
||||
|
||||
return vae_custom
|
||||
|
||||
|
||||
def find_lora_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
lora_hf_id = ""
|
||||
lora_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
lora_file = metadata[key]
|
||||
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not lora_custom and lora_file.count("/"):
|
||||
lora_hf_id = lora_file
|
||||
|
||||
# LoRA input is optional, should not print or throw an error if missing
|
||||
|
||||
return lora_custom, lora_hf_id
|
||||
|
||||
|
||||
def import_png_metadata(
|
||||
pil_data,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
sampler,
|
||||
cfg_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
):
|
||||
try:
|
||||
png_info = pil_data.info["parameters"]
|
||||
metadata = parse_generation_parameters(png_info)
|
||||
|
||||
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
|
||||
"Model", metadata
|
||||
)
|
||||
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
|
||||
"LoRA", metadata
|
||||
)
|
||||
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
|
||||
|
||||
negative_prompt = metadata["Negative prompt"]
|
||||
steps = int(metadata["Steps"])
|
||||
cfg_scale = float(metadata["CFG scale"])
|
||||
seed = int(metadata["Seed"])
|
||||
width = float(metadata["Size-1"])
|
||||
height = float(metadata["Size-2"])
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
hf_lora_id = ""
|
||||
if "LoRA" in metadata and lora_hf_model_id:
|
||||
custom_lora = "None"
|
||||
hf_lora_id = lora_hf_model_id
|
||||
|
||||
if "VAE" in metadata and vae_custom_model:
|
||||
custom_vae = vae_custom_model
|
||||
|
||||
if "Prompt" in metadata:
|
||||
prompt = metadata["Prompt"]
|
||||
if "Sampler" in metadata:
|
||||
if metadata["Sampler"] in scheduler_model_map:
|
||||
sampler = metadata["Sampler"]
|
||||
else:
|
||||
print(
|
||||
"Import PNG info: Unable to find a scheduler for %s"
|
||||
% metadata["Sampler"]
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
if pil_data and pil_data.info.get("parameters"):
|
||||
print("import_png_metadata failed with %s" % ex)
|
||||
pass
|
||||
|
||||
return (
|
||||
None,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
sampler,
|
||||
cfg_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
)
|
||||
41
apps/shark_studio/web/utils/state.py
Normal file
41
apps/shark_studio/web/utils/state.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
import gc
|
||||
|
||||
|
||||
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
|
||||
print(f"Getting status label for {tab_name}")
|
||||
if batch_index < batch_count:
|
||||
bs = f"x{batch_size}" if batch_size > 1 else ""
|
||||
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
|
||||
else:
|
||||
return f"{tab_name} complete"
|
||||
|
||||
|
||||
def get_generation_text_info(seeds, device):
|
||||
cfg_dump = {}
|
||||
for cfg in global_obj.get_config_dict():
|
||||
cfg_dump[cfg] = cfg
|
||||
text_output = f"prompt={cfg_dump['prompts']}"
|
||||
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
|
||||
text_output += (
|
||||
f"\nmodel_id={cfg_dump['hf_model_id']}, "
|
||||
f"ckpt_loc={cfg_dump['ckpt_loc']}"
|
||||
)
|
||||
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={cfg_dump['steps']}, "
|
||||
f"guidance_scale={cfg_dump['guidance_scale']}, "
|
||||
f"seed={seeds}"
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
|
||||
if not cfg_dump.use_hiresfix
|
||||
else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={cfg_dump['batch_count']}, "
|
||||
f"batch_size={cfg_dump['batch_size']}, "
|
||||
f"max_length={cfg_dump['max_length']}"
|
||||
)
|
||||
|
||||
return text_output
|
||||
77
apps/shark_studio/web/utils/tmp_configs.py
Normal file
77
apps/shark_studio/web/utils/tmp_configs.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import shutil
|
||||
from time import time
|
||||
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
def clear_tmp_mlir():
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing .mlir temporary files from a prior run. This may take some time..."
|
||||
)
|
||||
mlir_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.endswith(".mlir")
|
||||
]
|
||||
for filename in mlir_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
|
||||
def clear_tmp_imgs():
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
|
||||
print(
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
|
||||
)
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
|
||||
)
|
||||
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
|
||||
print(
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
|
||||
def config_tmp():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
Reference in New Issue
Block a user