Enable UI / bugfixes / tweaks

This commit is contained in:
Ean Garvey
2023-12-11 23:19:01 -06:00
parent ab32bfbe61
commit b3d5add7f6
29 changed files with 2648 additions and 466 deletions

View File

@@ -0,0 +1,134 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
# imports needed by cnet_preview below; controlnet_aux provides the detectors,
# and the EditorValue location assumes gradio 4.x
import numpy as np
import PIL
from PIL import Image
from controlnet_aux import CannyDetector, OpenposeDetector, ZoeDetector
from gradio.components.image_editor import EditorValue
class control_adapter:
def __init__(
self,
model: str,
):
self.model = None
def export_control_adapter_model(model_keyword):
return None
def export_xl_control_adapter_model(model_keyword):
return None
class preprocessors:
def __init__(
self,
model: str,
):
self.model = None
def export_controlnet_model(model_keyword):
return None
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {
"initializer": control_adapter.export_control_adapter_model
},
"scribble": {
"initializer": control_adapter.export_control_adapter_model
},
"zoedepth": {
"initializer": control_adapter.export_control_adapter_model
},
},
"sdxl": {
"canny": {
"initializer": control_adapter.export_xl_control_adapter_model
},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}
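For illustration, a minimal sketch of how these lookup maps might be consumed; the helper name resolve_initializer is hypothetical:

def resolve_initializer(base_model: str, stencil: str):
    # control_adapter_map is keyed by model family ("sd15"/"sdxl"),
    # then by stencil name; each entry carries an "initializer" callable.
    entry = control_adapter_map.get(base_model, {}).get(stencil)
    return entry["initializer"] if entry else None

# resolve_initializer("sd15", "canny") -> control_adapter.export_control_adapter_model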
class PreprocessorModel:
def __init__(
self,
hf_model_id,
device,
):
self.model = None
def compile(self, device):
print("compile not implemented for preprocessor.")
return
def run(self, inputs):
print("run not implemented for preprocessor.")
return
def cnet_preview(model, input_image, index, stencils, images, preprocessed_hints):
# note: index selects which stencil slot is being updated
if isinstance(input_image, PIL.Image.Image):
img_dict = {
"background": None,
"layers": [None],
"composite": input_image,
}
input_image = EditorValue(img_dict)
images[index] = input_image
if model:
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(
np.array(input_image["composite"]),
100,
200,
)
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result[0])
return (
Image.fromarray(result[0]),
stencils,
images,
preprocessed_hints,
)
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "scribble":
preprocessed_hints[index] = input_image["composite"]
return (
input_image["composite"],
stencils,
images,
preprocessed_hints,
)
case _:
preprocessed_hints[index] = None
return (
None,
stencils,
images,
preprocessed_hints,
)

View File

@@ -26,22 +26,21 @@ def imports():
startup_timer.record("import gradio")
# from apps.shark_studio.modules import shared_init
# shared_init.initialize()
# startup_timer.record("initialize shared")
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
startup_timer.record("initialize globals")
from apps.shark_studio.modules import (
processing,
gradio_extensons,
ui,
img_processing,
) # noqa: F401
from apps.shark_studio.modules.schedulers import scheduler_model_map
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
configure_opts_onchange()
# from apps.shark_studio.modules import modelloader
# modelloader.cleanup_models()

View File

@@ -1,9 +1,13 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import (
get_resource_path,
parse_seed_input,
)
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.img_processing import (
resize_stencil,
save_output_img,
get_generation_text_info,
)
import iree.runtime as ireert
import PIL
import gc
import torch
import gradio as gr
# module-level cache for the active pipeline and the kwargs used to build it
sd_pipe = None
sd_pipe_kwargs = None
sd_model_map = {
"CompVis/stable-diffusion-v1-4": {
@@ -86,16 +90,15 @@ sd_model_map = {
class StableDiffusion(SharkPipelineBase):
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
#
#
# custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
# e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
#
#
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.
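As a hedged illustration of the two dict shapes described above (the LoRA path is hypothetical):

custom_model_ids = {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
embeddings = {"lora": "loras/detail_tweaker.safetensors"}  # hypothetical path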
@@ -107,7 +110,6 @@ class StableDiffusion(SharkPipelineBase):
precision: str = "fp16",
device: str = None,
custom_model_map: dict = {},
custom_weights_map: dict = {},
embeddings: dict = {},
import_ir: bool = True,
):
@@ -118,12 +120,185 @@ class StableDiffusion(SharkPipelineBase):
self.iree_module_dict = None
self.get_compiled_map()
def prepare_pipeline(self, scheduler, custom_model_map, embeddings):
return None
def generate_images(
self,
prompt,
negative_prompt,
steps,
strength,
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
preprocessed_hints,
):
return None, None, None, None, None
# NOTE: Each `hf_model_id` should have its own starting configuration.
# model_vmfb_key = ""
def shark_sd_fn(
prompt,
negative_prompt,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
device: str,
lora_weights: str | list,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
preprocessed_hints: list,
progress=gr.Progress(),
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
continue
elif stencil is None and any(
img is not None for img in [images[i], preprocessed_hints[i]]
):
images[i] = None
preprocessed_hints[i] = None
elif images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
(
image,
_,
_,
) = resize_stencil(image, width, height)
is_img2img = True
print("Performing Stable Diffusion Pipeline setup...")
device_id = None
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
custom_model_map = {}
if custom_weights != "None":
custom_model_map["unet"] = {"custom_weights": custom_weights}
if custom_vae != "None":
custom_model_map["vae"] = {"custom_weights": custom_vae}
if stencils:
for i, stencil in enumerate(stencils):
if "xl" not in base_model_id.lower():
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"sd15"
][stencil]
else:
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"sdxl"
][stencil]
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
"width": width,
"precision": precision,
"device": device,
"custom_model_map": custom_model_map,
"import_ir": cmd_opts.import_mlir,
"is_img2img": is_img2img,
}
submit_prep_kwargs = {
"scheduler": scheduler,
"custom_model_map": custom_model_map,
"embeddings": lora_weights,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"preprocessed_hints": preprocessed_hints,
}
global sd_pipe
global sd_pipe_kwargs
if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
sd_pipe = None
sd_pipe_kwargs = submit_pipe_kwargs
gc.collect()
if sd_pipe is None:
print("Getting the pipeline ready...")
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
)
sd_pipe.prepare_pipeline(**submit_prep_kwargs)
seeds = parse_seed_input(seed)
generated_imgs = []
extra_info = {}  # placeholder for extra metadata saved alongside each image
for current_batch, out_imgs in enumerate(
progress.tqdm(
sd_pipe.generate_images(**submit_run_kwargs),
desc="Generating Image...",
)
):
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
save_output_img(
out_imgs[0],
seeds[current_batch],
extra_info,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
), stencils, images
return generated_imgs, text_output, "", stencils, images
def cancel_sd():
print("Inject call to cancel longer API calls.")
return
if __name__ == "__main__":
sd = StableDiffusion(

View File

@@ -2,6 +2,7 @@ import os
import sys
import os
import numpy as np
import glob
from random import (
randint,
seed as seed_random,
@@ -12,6 +13,19 @@ from random import (
from pathlib import Path
from datetime import datetime as dt
from safetensors.torch import load_file
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
# assumed to live alongside the vulkan utils above in shark.iree_utils
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_iree_rocm_args
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
def get_available_devices():
@@ -75,32 +89,119 @@ def get_available_devices():
return available_devices
def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
"--enable_rgp=true",
"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
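A minimal usage sketch, assuming iree-runtime is installed and a vulkan driver is present:

for device in get_all_devices("vulkan"):
    # each entry is a dict as returned by IREE's query_available_devices()
    print(device["path"])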
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
def get_generated_imgs_path() -> Path:
return Path(
cmd_opts.output_dir
if cmd_opts.output_dir
else get_resource_path("../web/generated_imgs")
)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def get_checkpoints_path(model = ""):
def create_checkpoint_folders():
subdirs = ["vae", "lora"]
if not cmd_opts.ckpt_dir:
subdirs.insert(0, "models")
else:
if not os.path.isdir(cmd_opts.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{cmd_opts.ckpt_dir} folder does not exist."
)
for root in subdirs:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model=""):
return get_resource_path(f"..\web\models\{model}")
def get_checkpoints(model="models"):
ckpt_files = []
file_types = checkpoints_filetypes
if model == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model="models"):
return os.path.join(get_checkpoints_path(model), checkpoint_name)
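A short usage sketch for the two helpers above, assuming at least one LoRA file exists in the lora folder:

loras = get_checkpoints(model="lora")  # e.g. ["detail_tweaker.safetensors"]
if loras:
    full_path = get_checkpoint_pathfile(loras[0], model="lora")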
def get_device_mapping(driver, key_combination=3):
@@ -142,6 +243,30 @@ def get_device_mapping(driver, key_combination=3):
return device_map
def get_opt_flags(model, precision="fp16"):
iree_flags = []
if len(cmd_opts.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
)
if "rocm" in cmd_opts.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
if not cmd_opts.iree_constant_folding:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if not cmd_opts.data_tiling:
iree_flags.append("--iree-opt-data-tiling=False")
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
@@ -248,6 +373,7 @@ def parse_seed_input(seed_input: str | list | int):
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
@@ -258,6 +384,7 @@ def sanitize_seed(seed: int | str):
seed = randint(uint32_min, uint32_max)
return seed
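For example, under the behavior described above:

seed = sanitize_seed(-1)  # -1 is outside the supported range, so a random uint32 is returned
seed = sanitize_seed(42)  # in-range seeds pass through unchanged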
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):

View File

@@ -0,0 +1,66 @@
import os
import json
import re
import requests
from io import BytesIO
from pathlib import Path
from omegaconf import OmegaConf
# assumed import location for the diffusers checkpoint-conversion helpers
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem)
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
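For example, given the layout logic above (the checkpoint path is hypothetical):

# "ckpts/dreamlike.safetensors" -> "ckpts/diffusers/dreamlike" (created if missing)
out_dir = get_path_to_diffusers_checkpoint("ckpts/dreamlike.safetensors")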
def preprocessCKPT(custom_weights, is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = custom_weights.lower().endswith(".safetensors")
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_state_dict, vae_config
)
return converted_vae_checkpoint

View File

@@ -1,5 +1,10 @@
import os
import sys
import torch
import json
import safetensors
from safetensors.torch import load_file
from apps.shark_studio.api.utils import get_checkpoint_pathfile
def processLoRA(model, use_lora, splitting_prefix):
@@ -109,3 +114,58 @@ def update_lora_weight(model, use_lora, model_name):
return processLoRA(model, use_lora, "lora_te_")
except:
return None
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_checkpoint_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))
tag_dirs = list(tag_frequencies.keys())
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset:
# either the count listed in its dataset_dirs entry, or the highest tag
# frequency if that doesn't exist
img_count = dataset_dirs.get(dataset, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the model identifier and all tag frequencies across all
# datasets, sorted by frequency in descending order
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
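A hedged usage sketch; the filename is hypothetical and must resolve under the lora checkpoints folder:

info = get_lora_metadata("my_style.safetensors")
if info:
    print("trained against:", info["model"])
    for tag, ratio in info["frequencies"][:5]:
        # ratio is the tag count divided by the dataset's image count
        print(f"{tag}: {ratio:.0%}")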

View File

@@ -1,4 +1,8 @@
import os
import sys
import re
import json
import csv
from datetime import datetime as dt
from PIL import Image, PngImagePlugin
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
@@ -10,43 +14,45 @@ def save_output_img(output_img, img_seed, extra_info=None):
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_model = cmd_opts.hf_model_id
if cmd_opts.ckpt_loc:
img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
if cmd_opts.custom_vae:
img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if cmd_opts.use_lora:
img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
if args.output_img_format == "jpg":
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if args.use_hiresfix:
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
if cmd_opts.use_hiresfix:
png_size_text = (
f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
)
else:
png_size_text = f"{args.width}x{args.height}"
png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}"
f"\nNegative prompt: {args.negative_prompts[0]}"
f"\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"{cmd_opts.prompts[0]}"
f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
f"\nSteps: {cmd_opts.steps},"
f"Sampler: {cmd_opts.scheduler}, "
f"CFG scale: {cmd_opts.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
@@ -56,9 +62,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not "
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
@@ -68,18 +74,20 @@ def save_output_img(output_img, img_seed, extra_info=None):
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SCHEDULER": cmd_opts.scheduler,
"PROMPT": cmd_opts.prompts[0],
"NEG_PROMPT": cmd_opts.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height
if not args.use_hiresfix
else args.hiresfix_height,
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
"MAX_LENGTH": args.max_length,
"CFG_SCALE": cmd_opts.guidance_scale,
"PRECISION": cmd_opts.precision,
"STEPS": cmd_opts.steps,
"HEIGHT": cmd_opts.height
if not cmd_opts.use_hiresfix
else cmd_opts.hiresfix_height,
"WIDTH": cmd_opts.width
if not cmd_opts.use_hiresfix
else cmd_opts.hiresfix_width,
"MAX_LENGTH": cmd_opts.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
@@ -95,37 +103,23 @@ def save_output_img(output_img, img_seed, extra_info=None):
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
if cmd_opts.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={cmd_opts.prompts}"
text_output += f"\nnegative prompt={cmd_opts.negative_prompts}"
text_output += (
f"\nmodel_id={cmd_opts.hf_model_id}, " f"ckpt_loc={cmd_opts.ckpt_loc}"
)
text_output += f"\nscheduler={cmd_opts.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={cmd_opts.steps}, "
f"guidance_scale={cmd_opts.guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={cmd_opts.height}x{cmd_opts.width}, "
if not cmd_opts.use_hiresfix
else f"\nsize={cmd_opts.hiresfix_height}x{cmd_opts.hiresfix_width}, "
)
text_output += (
f"batch_count={cmd_opts.batch_count}, "
f"batch_size={cmd_opts.batch_size}, "
f"max_length={cmd_opts.max_length}"
)
return text_output
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()
# For stencil, the input image can be of any size, but we need to ensure that
@@ -133,7 +127,7 @@ def get_generation_text_info(seeds, device):
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height):
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
@@ -166,6 +160,9 @@ def resize_stencil(image: Image.Image, width, height):
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
if resampler_type in resamplers:
resampler = resamplers[resampler_type]
else:
resampler = resamplers["Nearest Neighbor"]
new_image = image.resize((n_width, n_height), resampler=resampler)
return new_image, n_width, n_height
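A short usage sketch, assuming input.png is any RGB image on disk:

img = Image.open("input.png")
resized, n_w, n_h = resize_stencil(img, 640, 512, resampler_type="Bicubic")
# n_w and n_h are clamped to [128, 768] and rounded to multiples of 8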

View File

@@ -0,0 +1,71 @@
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import gc
class SharkPipelineBase:
# This class is a lightweight base for managing an
# inference API class. It should provide methods for:
# - compiling a set (model map) of torch IR modules
# - preparing weights for an inference job
# - loading weights for an inference job
# - utilities like benchmarks and tests
def __init__(
self,
model_map: dict,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.device = device
self.import_mlir = import_mlir
self.iree_module_dict = {}
def import_torch_ir(self, base_model_id):
for submodel in self.model_map:
hf_id = (
self.model_map[submodel]["custom_hf_id"]
if self.model_map[submodel].get("custom_hf_id")
else base_model_id
)
torch_ir = self.model_map[submodel]["initializer"](
hf_id, **self.model_map[submodel]["init_kwargs"], compile_to="torch"
)
self.model_map[submodel]["tempfile_name"] = get_resource_path(
f"{submodel}.torch.tempfile"
)
with open(self.model_map[submodel]["tempfile_name"], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
def load_vmfb(self, submodel):
if self.iree_module_dict.get(submodel):
print(
f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
)
elif self.model_map[submodel]["tempfile_name"]:
# compile the saved torch IR for this submodel on demand
self.get_compiled_map(self.device)
return self.iree_module_dict[submodel]["vmfb"]
def merge_custom_map(self, custom_model_map):
for submodel in custom_model_map:
for key in custom_model_map[submodel]:
self.model_map[submodel][key] = custom_model_map[submodel][key]
print(self.model_map)
def get_compiled_map(self, device) -> None:
# each entry comes with keys: "vmfb", "config", and "temp_file_to_unlink".
for submodel in self.model_map:
if not self.iree_module_dict.get(submodel, {}).get("vmfb"):
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.model_map[submodel]["tempfile_name"],
device=self.device,
frontend="torch",
)
# TODO: delete the temp file
def run(self, submodel, inputs):
return
def safe_name(name):
return name.replace("/", "_").replace("-", "_")

View File

@@ -2,7 +2,7 @@ import argparse
import os
from pathlib import Path
from apps.stable_diffusion.src.utils.resamplers import resampler_list
from apps.shark_studio.modules.img_processing import resampler_list
def path_expand(s):
@@ -36,7 +36,7 @@ p.add_argument(
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
@@ -306,21 +306,6 @@ p.add_argument(
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="Saves the compiled flat-buffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=False,
@@ -446,7 +431,7 @@ p.add_argument(
)
p.add_argument(
"--ondemand",
"--lowvram",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
@@ -469,10 +454,10 @@ p.add_argument(
)
p.add_argument(
"--autogen",
type=bool,
default="False",
help="Only used for a gradio workaround.",
"--custom_model_map",
type=str,
default="",
help="path to custom model map to import. This should be a .json file",
)
##############################################################################
# IREE - Vulkan supported flags
@@ -612,6 +597,13 @@ p.add_argument(
# Web UI flags
##############################################################################
p.add_argument(
"--webui",
default=True,
action=argparse.BooleanOptionalAction,
help="controls whether the webui is launched.",
)
p.add_argument(
"--progress_bar",
default=True,
@@ -764,8 +756,8 @@ p.add_argument(
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
args, unknown = p.parse_known_args()
if args.import_debug:
cmd_opts, unknown = p.parse_known_args()
if cmd_opts.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), args.hf_model_id.replace("/", "_")
os.getcwd(), cmd_opts.hf_model_id.replace("/", "_")
)

View File

@@ -15,7 +15,7 @@ from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio import sd_samplers, postprocessing, errors, restart
from apps.shark_studio.modules.img_processing import sampler_list
from sdapi_v1 import shark_sd_api
from api.llm import chat_api
@@ -26,15 +26,21 @@ def decode_base64_to_image(encoding):
raise HTTPException(status_code=500, detail="Requests not allowed")
if opts.api_forbid_local_requests and not verify_url(encoding):
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
raise HTTPException(
status_code=500, detail="Request to local resource not allowed"
)
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
headers = (
{"user-agent": opts.api_useragent} if opts.api_useragent else {}
)
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
raise HTTPException(
status_code=500, detail="Invalid image url"
) from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
@@ -42,32 +48,54 @@ def decode_base64_to_image(encoding):
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
raise HTTPException(
status_code=500, detail="Invalid encoded image"
) from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if opts.samples_format.lower() == 'png':
if opts.samples_format.lower() == "png":
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
image.save(
output_bytes,
format="PNG",
pnginfo=(metadata if use_metadata else None),
quality=opts.jpeg_quality,
)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
if image.mode == "RGBA":
image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
})
parameters = image.info.get("parameters", None)
exif_bytes = piexif.dump(
{
"Exif": {
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(
parameters or "", encoding="unicode"
)
}
}
)
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
image.save(
output_bytes,
format="JPEG",
exif=exif_bytes,
quality=opts.jpeg_quality,
)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
image.save(
output_bytes,
format="WEBP",
exif=exif_bytes,
quality=opts.jpeg_quality,
)
else:
raise HTTPException(status_code=500, detail="Invalid image format")
@@ -80,10 +108,11 @@ def encode_pil_to_base64(image):
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
@@ -95,35 +124,49 @@ def api_middleware(app: FastAPI):
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get('http_version', '0.0'),
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
prot=req.scope.get('scheme', 'err'),
method=req.scope.get('method', 'err'),
endpoint=endpoint,
duration=duration,
))
endpoint = req.scope.get("path", "err")
if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"):
print(
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get("http_version", "0.0"),
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
prot=req.scope.get("scheme", "err"),
method=req.scope.get("method", "err"),
endpoint=endpoint,
duration=duration,
)
)
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get('detail', ''),
"body": vars(e).get('body', ''),
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
if not isinstance(
e, HTTPException
): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
return JSONResponse(
status_code=vars(e).get("status_code", 500),
content=jsonable_encoder(err),
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
@@ -143,52 +186,48 @@ def api_middleware(app: FastAPI):
class ApiCompat:
def __init__(self, queue_lock: Lock):
self.router = APIRouter()
self.app = FastAPI()
self.queue_lock = queue_lock
api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"])
self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"])
#self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
#self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
#self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
#self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
#self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
#self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
#self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
#self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
#self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
#self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
#self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
#self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
#self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
#self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
#self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
#self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
#self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
#self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
#self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
#self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
#self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
#self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
#self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
#self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
#self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
#self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
#self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
#self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
#self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
# self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
# self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
# self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
# self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
# self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
# self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
# self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
# self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
# self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
# self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
# self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
# self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
# self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
# self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
# self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
# self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
# self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
# self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
# self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
# self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
# self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
# self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
self.add_api_route(
"/v1/chat/completions", chat_api, methods=["post"]
)
self.add_api_route("/v1/chat/completions", chat_api, methods=["post"])
self.add_api_route("/v1/completions", chat_api, methods=["post"])
self.add_api_route("/chat/completions", chat_api, methods=["post"])
self.add_api_route("/completions", chat_api, methods=["post"])
@@ -196,16 +235,26 @@ class ApiCompat:
"/v1/engines/codegen/completions", chat_api, methods=["post"]
)
if studio.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"])
self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"])
self.add_api_route(
"/sdapi/v1/server-kill", self.kill_studio, methods=["POST"]
)
self.add_api_route(
"/sdapi/v1/server-restart",
self.restart_studio,
methods=["POST"],
)
self.add_api_route(
"/sdapi/v1/server-stop", self.stop_studio, methods=["POST"]
)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path:str, endpoint, **kwargs):
def add_api_route(self, path: str, endpoint, **kwargs):
if studio.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(
path, endpoint, dependencies=[Depends(self.auth)], **kwargs
)
return self.app.add_api_route(path, endpoint, **kwargs)
def refresh_checkpoints(self):
@@ -231,7 +280,13 @@ class ApiCompat:
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path)
uvicorn.run(
self.app,
host=server_name,
port=port,
timeout_keep_alive=studio.cmd_opts.timeout_keep_alive,
root_path=root_path,
)
def kill_studio(self):
restart.stop_program()
@@ -246,7 +301,7 @@ class ApiCompat:
studio.state.begin(job="preprocess")
preprocess(**args)
studio.state.end()
return models.PreprocessResponse(info='preprocess complete')
return models.PreprocessResponse(info="preprocess complete")
except:
studio.state.end()

View File

@@ -0,0 +1 @@
{}

View File

@@ -3,12 +3,13 @@ import os
import time
import sys
import logging
import apps.shark_studio.api.initializers as initialize
from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element
from modules import timer, initialize
from apps.shark_studio.modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")
@@ -72,15 +73,13 @@ def launch_webui(address):
def webui():
from apps.shark_studio.shared_cmd_options import cmd_opts
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
logging.basicConfig(level=logging.DEBUG)
launch_api = cmd_opts.api
initialize.initialize()
from modules import shared, ui_tempdir, script_callbacks, ui, progress
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
@@ -131,16 +130,23 @@ def webui():
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.shark_studio.web.initializers import (
config_gradio_tmp_imgs_folder,
create_custom_models_folders,
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)
from apps.shark_studio.api.utils import (
create_checkpoint_folders,
)
config_gradio_tmp_imgs_folder()
import gradio as gr
config_tmp()
clear_tmp_mlir()
clear_tmp_imgs()
# Create custom models folders if they don't exist
create_custom_models_folders()
create_checkpoint_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
@@ -149,10 +155,7 @@ def webui():
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.shark_studio.web.ui import load_ui_from_script
# init global sd pipeline and config
studio.state._init()
# from apps.shark_studio.web.ui import load_ui_from_script
def register_button_click(button, selectedid, inputs, outputs):
button.click(
@@ -209,9 +212,9 @@ def webui():
if __name__ == "__main__":
from apps.shark_studio.shared_cmd_options import cmd_opts
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if cmd_opts.nowebui:
if not cmd_opts.webui:
api_only()
else:
webui()

View File

@@ -0,0 +1,55 @@
from apps.shark_studio.web.ui.utils import (
HSLHue,
hsl_color,
)
from apps.shark_studio.modules.embeddings import get_lora_metadata
# Returns HTML showing the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency ratio at which a tag gets the maximum amount of the starting hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency ratio above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]

View File

@@ -0,0 +1,324 @@
/*
Apply Gradio dark theme to the default Gradio theme.
Procedure to upgrade the dark theme:
- Using your browser, visit http://localhost:8080/?__theme=dark
- Open your browser inspector, search for the .dark css class
- Copy .dark class declarations, apply them here into :root
*/
:root {
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
.gradio-container {
max-width: var(--size-full) !important;
}
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}
#gallery + div {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
display: none;
}
/* Hide selected items from ui dropdowns */
#custom_model .options .item .inner-item,
#scheduler .options .item .inner-item,
#device .options .item .inner-item,
#stencil_model .options .item .inner-item {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs);
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
}
.outputgallery_sendto {
min-width: 7em !important;
}
/* output gallery should take up most of the viewport height regardless of image size/number */
#outputgallery_gallery .fixed-height {
min-height: 89vh !important;
}
/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;
width: 100%;
}
#top_logo.logo_centered img {
object-fit: scale-down;
position: absolute;
width: 80%;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}

Binary file not shown. (After: 16 KiB image)

Binary file not shown. (After: 10 KiB image)

View File

@@ -0,0 +1,416 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.web.ui.utils import nodlogo_loc
from apps.shark_studio.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext)
for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=cmd_opts.output_gallery_followlinks
)
]
    # It is less confusing to always include the subdir that will receive any
    # images generated today, even if it doesn't exist yet
    if get_generated_imgs_todays_subdir() not in relative_paths:
        relative_paths.append(get_generated_imgs_todays_subdir())
    # Sort subdirectories so that the date-named ones we probably created in
    # this or previous sessions come first, most recent first. Other subdirs
    # are listed after (example ordering shown just below this function).
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[
path
for path in relative_paths
if (not path.isnumeric()) and path != "."
]
)
return result_paths
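# For illustration, assuming date subdirs are purely numeric (e.g. "20231211")
# and the folder names below are invented, the returned ordering would be:
#
#   output_subdirs()  # -> ["20231212", "20231211", "loras_test"]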
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_element:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=nod_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
show_download_button=False,
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=4,
)
with gr.Column(scale=4):
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordion:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_sd = gr.Button(
value="Stable Diffusion",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
# --- Event handlers
def on_clear_gallery():
return [
gr.Gallery(
value=[],
visible=False,
),
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
    # subdir is the name of the newly selected subdirectory
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
)
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button(interactive=exists),
]
# The first time our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls'
# definitions because when you refresh the browser you always get what
# was *initially* set, which won't include any new subdirectories or
# images that might have been created since the application was started.
# Doing it this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# Clearing images when we need to completely change what's in the
# gallery avoids the current images being replaced piecemeal, and
# prevents weirdness and errors if the user selects an image during
# the replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_sd,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
)
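# Illustrative wiring from the top-level UI (a sketch, not part of this
# file; the tab variable and status textbox names are assumptions):
#
#   with gr.TabItem("Output Gallery") as outputgallery_tab:
#       outputgallery_element.render()
#   outputgallery_tab_select(outputgallery_tab.select)
#   outputgallery_watch([sd_status])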

View File

@@ -24,130 +24,31 @@ from apps.shark_studio.api.utils import (
)
from apps.shark_studio.api.sd import (
sd_model_map,
StableDiffusion,
)
from apps.shark_studio.api.schedulers import (
scheduler_model_map,
shark_sd_fn,
cancel_sd,
)
from apps.shark_studio.api.controlnet import (
preprocessor_model_map,
control_adapter_model_map,
PreprocessorModel,
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.modules.img_processing import (
resampler_list,
resize_stencil,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
get_generation_text_info,
nodlogo_loc,
)
from apps.shark_studio.web.utils.state import (
get_generation_text_info,
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
sd_pipe = None
# NOTE: Each `hf_model_id` should have its own starting configuration.
# model_vmfb_key = ""
def shark_sd_fn(
prompt: str,
negative_prompt: str,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_checkpoints: str,
custom_vae: str,
precision: str,
device: str,
lora_weights: str | list,
lora_hf_ids: str | list,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
preprocessed_hints: list,
progress=gr.Progress(),
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
continue
elif stencil is None and any(img is not None for img in [images[i], preprocessed_hints[i]]):
images[i] = None
preprocessed_hints[i] = None
elif images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
image = None
if image:
image, _, _ = resize_stencil(image, width, height)
device_id = None
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
submit_pipe_kwargs = {
    "base_model_id": base_model_id,
    "height": height,
    "width": width,
    "precision": precision,
    "device": device,
    "extra_model_ids": extra_model_ids,
    "embeddings": lora_hf_ids,
    "import_ir": cmd_opts.import_ir,
}
submit_prep_kwargs = {
global sd_pipe
global sd_pipe_kwargs
for key in
if sd_pipe is None:
history[-1][-1] = "Getting the pipeline ready..."
yield history, ""
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = SharkStableDiffusionPipeline(
**submit_pipe_kwargs
)
sd_pipe.queue_compile()
for prompt, msg, exec_time in progress.tqdm(
sd_pipe.generate_images(
prompt,
negative_prompt,
),
desc="Generating Image...",
):
return history, ""
def view_json_file(file_obj):
content = ""
@@ -155,17 +56,33 @@ def view_json_file(file_obj):
content = fopen.read()
return content
sd_fn_sig = signature(shark_sd_fn)
max_controlnets = 5
max_controlnets = 3
max_loras = 5
def show_loras(k):
k = int(k)
return [gr.Dropdown(visible=True)]*k + [gr.Dropdown(visible=False, value="None")]*(max_textboxes-k)
return gr.State(
[gr.Dropdown(visible=True)] * k
+ [gr.Dropdown(visible=False, value="None")] * (max_loras - k)
)
def show_controlnets(k):
k = int(k)
return [gr.Row(visible=True)]*k + [gr.Row(visible=False)]*(max_textboxes-k)
return [
gr.State(
[
[gr.Row(visible=True, render=True)] * k
+ [gr.Row(visible=False)] * (max_controlnets - k)
]
),
gr.State([None] * k),
gr.State([None] * k),
gr.State([None] * k),
]
def create_canvas(width, height):
data = Image.fromarray(
@@ -182,10 +99,9 @@ def create_canvas(width, height):
}
return EditorValue(img_dict)
def import_original(original_img, width, height):
resized_img, _, _ = resize_stencil(
original_img, width, height
)
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [resized_img],
@@ -196,6 +112,7 @@ def import_original(original_img, width, height):
crop_size=(width, height),
)
def update_cn_input(
model,
width,
@@ -203,7 +120,6 @@ def update_cn_input(
stencils,
images,
preprocessed_hints,
index,
):
if model == None:
stencils[index] = None
@@ -271,80 +187,99 @@ def update_cn_input(
images,
preprocessed_hints,
]
sd_fn_inputs = []
sd_fn_sig = signature(shark_sd_fn).replace()
for i in sd_fn_sig.parameters:
sd_fn_inputs.append(i)
with gr.Blocks(title="Stable Diffusion") as sd_element:
# Get a list of arguments needed for the API call, then
# initialize an empty list that will manage the corresponding
# gradio values.
inputs_list = gr.State(signature(shark_sd_fn))
inputs_args = gr.State([None] * len(inputs_list))
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
save_sd_config = gr.Button(label="Save Config", scale=1)
load_sd_config = gr.FileExplorer("Load Config", scale=1)
clear_sd_config = gr.ClearButton("Clear Config", scale=1)
with gr.Column(elem_if="ui_body"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row(variant="compact", equal_height=True):
with gr.Column(
scale=1,
elem_id="demo_title_outer",
):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Group():
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoint_path())}"
)
sd_base = gr.Dropdown(
label="Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2.1-base",
choices=get_base_models(),
) # base_model_id
sd_checkpoint = gr.Dropdown(
label="Checkpoints (optional)",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
choices=get_checkpoints(sd_base),
) #
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
sd_custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
info=sd_vae_info,
elem_id="custom_model",
value=os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None",
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoints_path())}"
)
sd_base = gr.Dropdown(
label="Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_model_map.keys(),
) # base_model_id
sd_custom_weights = gr.Dropdown(
label="Weights (Optional)",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=get_checkpoints(sd_base),
) #
with gr.Column(scale=2):
sd_vae_info = (
str(get_checkpoints_path("vae"))
).replace("\\", "\n\\")
sd_vae_info = f"VAE Path: {sd_vae_info}"
sd_custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
info=sd_vae_info,
elem_id="custom_model",
value=os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None",
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config", size="sm"
)
load_sd_config = gr.FileExplorer(
label="Load Config",
root=os.path.basename("./configs"),
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
value=cmd_opts.prompts[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
value=cmd_opts.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label = "Input Image", open=False):
with gr.Accordion(label="Input Image", open=False):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
label="Input Image",
@@ -352,41 +287,94 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
height=300,
interactive=True,
)
with gr.Accordion(label="Embeddings options", open=False):
with gr.Accordion(
label="Embeddings options", open=False, render=True
):
sd_lora_info = (
str(get_checkpoints_path("loras"))
).replace("\\", "\n\\")
num_loras = gr.Slider(1, max_loras, value=1, step=1, label="LoRA Count")
loras = []
num_loras = gr.Slider(
1, max_loras, value=1, step=1, label="LoRA Count"
)
loras = gr.State([])
for i in range(max_loras):
lora_opt = gr.Dropdown(
allow_custom_value=False,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
with gr.Row():
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_checkpoints("lora"),
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
)
loras.value.append(lora_opt)
num_loras.change(show_loras, [num_loras], [loras])
with gr.Accordion(label="Advanced Options", open=True):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list,
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
height = gr.Slider(
384, 768, value=cmd_opts.height, step=8, label="Height"
384,
768,
value=cmd_opts.height,
step=8,
label="Height",
)
width = gr.Slider(
384, 768, value=cmd_opts.width, step=8, label="Width"
384,
768,
value=cmd_opts.width,
step=8,
label="Width",
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
1,
100,
value=cmd_opts.steps,
step=1,
label="Steps",
)
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Column(scale=3):
strength = gr.Slider(
@@ -402,6 +390,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label="Resample Type",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
@@ -416,38 +411,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
],
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
with gr.Row():
seed = gr.Textbox(
value=cmd_opts.seed,
@@ -457,40 +420,53 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device = gr.Dropdown(
elem_id="device",
label="Device",
value=get_available_devices[0],
choices=get_available_devices,
value=get_available_devices()[0],
choices=get_available_devices(),
allow_custom_value=False,
)
with gr.Accordion(label="Controlnet Options", open=False):
with gr.Accordion(
label="Controlnet Options", open=False, render=False
):
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
num_cnets = gr.Slider(1, max_controlnets, value=1, step=1, label="Controlnet Count")
num_cnets = gr.Slider(
0,
max_controlnets,
value=0,
step=1,
label="Controlnet Count",
)
cnet_rows = []
stencils = []
images = []
preprocessed_hints = []
stencils = gr.State([])
images = gr.State([])
preprocessed_hints = gr.State([])
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
for i in range(max_controlnets):
with gr.Row as cnet_row:
with gr.Row(visible=False) as cnet_row:
with gr.Column():
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
cnet_processor = gr.Dropdown(
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Preprocessor",
label=f"Controlnet Model",
info=sd_cnet_info,
elem_id="lora_weights",
value="None",
choices=["None"] + controlnet_list + get_custom_model_files("controlnet"),
)
cnet_adapter = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Adapter",
info=sd_cnet_info,
elem_id="lora_weights",
value="None",
choices=["None"] + controlnet_list + get_custom_model_files("controlnet"),
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
canvas_width = gr.Slider(
label="Canvas Width",
@@ -529,14 +505,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
visible=True,
label="Preprocessed Hint",
interactive=True,
show_label=True
show_label=True,
)
use_input_img.click(
import_original,
[sd_init_image, canvas_width, canvas_height],
[cnet_image],
[cnet_input],
)
cnet_model.change(
fn=update_cn_input,
inputs=[
@@ -563,7 +538,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
create_canvas,
[canvas_width, canvas_height],
[
cnet_image,
cnet_input,
],
)
gr.on(
@@ -583,12 +558,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
preprocessed_hints,
],
)
cnet_rows.append(cnet_row)
cnet_rows.value.append(cnet_row)
num_cnets.change(show_controlnets, num_cnets, cnet_rows)
num_cnets.change(
show_controlnets,
[num_cnets],
[cnet_rows, stencils, images, preprocessed_hints],
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
img2img_gallery = gr.Gallery(
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
@@ -596,14 +575,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"{i2i_model_info}\n"
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=False,
)
img2img_status = gr.Textbox(visible=False)
sd_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
@@ -631,12 +610,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
batch_size,
scheduler,
sd_base,
sd_checkpoint,
sd_custom_weights,
sd_custom_vae,
precision,
device,
lora_weights,
lora_hf_id,
loras,
ondemand,
repeatable_seeds,
resample_type,
@@ -652,13 +630,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
stencils,
images,
],
show_progress="minimal" if cmd_opts.progress_bar else "none",
show_progress="minimal",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=img2img_status,
outputs=sd_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
@@ -670,10 +648,3 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -1,10 +1,33 @@
def nodlogo_loc():
return "foo"
from enum import IntEnum
import math
import sys
import os
def get_checkpoints_path(model_type: str = None):
return "foo"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_checkpoints():
return "foo"
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"

View File

@@ -0,0 +1,74 @@
import gc
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
global _sd_obj
_sd_obj = value
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_cfg_obj(value):
global _config_obj
_config_obj = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
global _sd_obj
return _sd_obj
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_cfg_obj():
global _config_obj
return _config_obj
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global _sd_obj
global _config_obj
global _schedulers
del _sd_obj
del _config_obj
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_schedulers = None
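# Typical lifecycle, as a sketch (the pipeline and scheduler values shown
# here are placeholders, not provided by this module):
#
#   import apps.shark_studio.web.utils.globals as global_obj
#
#   global_obj._init()
#   global_obj.set_schedulers({"EulerDiscrete": euler_scheduler})  # placeholder
#   global_obj.set_sd_obj(sd_pipeline)                             # placeholder
#   global_obj.set_sd_scheduler("EulerDiscrete")
#   ...
#   global_obj.clear_cache()  # drop pipeline/config refs and force a gc pass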

View File

@@ -0,0 +1,6 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)

View File

@@ -0,0 +1,45 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def matching_filename(image_filename: str, row):
    # We assume the final column of the csv has the original filename with
    # full path, and match that against the image_filename if we are given
    # a list. Otherwise we assume a dict and take the value of the OUTPUT key.
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}
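# Minimal usage sketch (the image path is made up for illustration):
#
#   image = "outputs/20231211/img_0.png"
#   if has_csv(image):
#       params = parse_csv(image)  # {} if no row matched the filename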

View File

@@ -0,0 +1,53 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
    # We have PNG generation parameters (preferred, as it's what the txt2img
    # dropzone reads, it's what SendTo goes via, and it's directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
    # We have a CSV file so try that (it can be different shapes, and it
    # usually has no headers/param names, so of the things we *know* have
    # parameters, it's the last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None
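# Callers get a consistent shape regardless of which source matched; e.g.
# (values here are illustrative, not real output):
#
#   displayable_metadata("outputs/20231211/img_0.png")
#   # -> {"source": "png", "parameters": {"Prompt": "...", "Steps": "20", ...}}
#   # -> None when no metadata of any kind was found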

View File

@@ -0,0 +1,52 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return bool(Image.open(image_filename).getexif())
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
    # See this stackoverflow answer for where most of this comes from:
    # https://stackoverflow.com/a/75357594
    # I did try to use the exif library but it broke just as much as my
    # initial attempt at this (albeit I was probably using it wrong), so I
    # reverted back to using PIL with more filtering and saved a dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}

View File

@@ -0,0 +1,143 @@
# As SHARK has evolved, more columns have been added to images_details.csv.
# However, since no version of the CSV has any headers (yet), we don't
# actually have anything within the file that tells us which parameter each
# column is for. So this is a list of known patterns indexed by length,
# which is what we have to use to guess which columns are the right ones
# for the file we're looking at.
# The same ordering is used for JSON, but those do have key names; however,
# they are not very human friendly, nor do they match up with what is
# written to the .png headers.
# So these are functions to try and get something consistent out of the
# raw input from all these sources.
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (
not result["VAE"] or result["VAE"] == "None"
):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (
not result["LoRA"] or result["LoRA"] == "None"
):
result.pop("LoRA")
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> bool:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_CURRENT
return {
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")
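# Usage sketch: a headerless 9-column CSV row (values invented for
# illustration) is matched on length and mapped to the friendly names:
#
#   row = ["sd2.1", "EulerDiscrete", "a cat", "", "42", "7.5", "fp16",
#          "20", "outputs/img_0.png"]
#   humanize(row)
#   # -> {"Model": "sd2.1", "Sampler": "EulerDiscrete", "Prompt": "a cat",
#   #     "Negative prompt": "", "Seed": "42", "CFG scale": "7.5",
#   #     "Precision": "fp16", "Steps": "20", "Filename": "outputs/img_0.png"}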

View File

@@ -0,0 +1,222 @@
import re
from pathlib import Path
from apps.shark_studio.api.utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import (
sd_model_map,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
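# Example of the A1111-style text this parses (an invented sketch):
#
#   text = ("a cat\n"
#           "Negative prompt: blurry\n"
#           "Steps: 20, Sampler: EulerDiscrete, CFG scale: 7.5, "
#           "Seed: 42, Size: 512x512")
#   parse_generation_parameters(text)
#   # -> {"Prompt": "a cat", "Negative prompt": "blurry", "Steps": "20",
#   #     "Sampler": "EulerDiscrete", "CFG scale": "7.5", "Seed": "42",
#   #     "Size-1": "512", "Size-2": "512", "Clip skip": "1",
#   #     "Hires resize-1": 0, "Hires resize-2": 0}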
def try_find_model_base_from_png_metadata(
file: str, folder: str = "models"
) -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in sd_model_map:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_model_map:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -0,0 +1,41 @@
import apps.shark_studio.web.utils.globals as global_obj
import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
print(f"Getting status label for {tab_name}")
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"
def get_generation_text_info(seeds, device):
cfg_dump = {}
    for key, val in global_obj.get_config_dict().items():
        cfg_dump[key] = val
text_output = f"prompt={cfg_dump['prompts']}"
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
text_output += (
f"\nmodel_id={cfg_dump['hf_model_id']}, "
f"ckpt_loc={cfg_dump['ckpt_loc']}"
)
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
text_output += (
f"\nsteps={cfg_dump['steps']}, "
f"guidance_scale={cfg_dump['guidance_scale']}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
if not cfg_dump["use_hiresfix"]
else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
)
text_output += (
f"batch_count={cfg_dump['batch_count']}, "
f"batch_size={cfg_dump['batch_size']}, "
f"max_length={cfg_dump['max_length']}"
)
return text_output

View File

@@ -0,0 +1,77 @@
import os
import shutil
from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():
cleanup_start = time()
print(
"Clearing .mlir temporary files from a prior run. This may take some time..."
)
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
print(
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
)
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
print(
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
print(
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to work around gradio bugs and stored things differently
else:
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()
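# Intended startup usage, as a sketch (the module path is assumed from the
# repo layout, and config_tmp() should run once before the UI launches):
#
#   from apps.shark_studio.web.utils.tmp import config_tmp
#
#   config_tmp()  # ensure shark_tmp exists, then clear stale .mlir and image files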