Migrate to JSON requests in the SD UI

This commit is contained in:
Ean Garvey
2023-12-12 23:22:11 -06:00
parent 5e675170b8
commit 961a5adda4
12 changed files with 561 additions and 483 deletions

View File

@@ -1,5 +1,15 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)
class control_adapter:
def __init__(
@@ -57,20 +67,26 @@ class PreprocessorModel:
def __init__(
self,
hf_model_id,
device,
device = "cpu",
):
self.model = None
self.model = hf_model_id
self.device = device
def compile(self, device):
def compile(self):
print("compile not implemented for preprocessor.")
return
def run(self, inputs):
print("run not implemented for preprocessor.")
return
return inputs
def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
def cnet_preview(model, input_image, stencil, preprocessed_hint):
curr_datetime = datetime.now().strftime('%Y-%m-%d.%H-%M-%S')
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
if isinstance(input_image, PIL.Image.Image):
img_dict = {
"background": None,
@@ -78,57 +94,52 @@ def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
"composite": input_image,
}
input_image = EditorValue(img_dict)
images[index] = input_image
preprocessed_hint = img_dest
if model:
stencils[index] = model
stencil = model
match model:
case "canny":
canny = CannyDetector()
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image["composite"]),
100,
200,
)
preprocessed_hints[index] = Image.fromarray(result)
Image.fromarray(result).save(fp=img_dest)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
)
case "openpose":
openpose = OpenposeDetector()
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result[0])
Image.fromarray(result[0]).save(fp=img_dest)
return (
Image.fromarray(result[0]),
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
)
case "zoedepth":
zoedepth = ZoeDetector()
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result)
Image.fromarray(result).save(fp=img_dest)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
)
case "scribble":
preprocessed_hints[index] = input_image["composite"]
input_image["composite"].save(fp=img_dest)
return (
input_image["composite"],
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
)
case _:
preprocessed_hints[index] = None
preprocessed_hint = None
return (
None,
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
)

View File

@@ -4,7 +4,7 @@ from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.web.utils.file_utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc

View File

@@ -1,13 +1,14 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.modules.pipeline import SharkPipelineBase
import iree.runtime as ireert
from apps.shark_studio.modules.img_processing import resize_stencil, save_output_img
from math import ceil
import gc
import torch
import gradio as gr
import PIL
import time
sd_model_map = {
"CompVis/stable-diffusion-v1-4": {
@@ -154,7 +155,7 @@ def shark_sd_fn(
steps: int,
strength: float,
guidance_scale: float,
seed: str | int,
seeds: list,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -168,24 +169,13 @@ def shark_sd_fn(
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
preprocessed_hints: list,
sd_json: dict,
progress=gr.Progress(),
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
continue
elif stencil is None and any(
img is not None for img in [images[i], preprocessed_hints[i]]
):
images[i] = None
preprocessed_hints[i] = None
elif images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
stencils=[]
preprocessed_hints=[]
cnet_strengths=[]
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
@@ -203,8 +193,6 @@ def shark_sd_fn(
is_img2img = True
print("Performing Stable Diffusion Pipeline setup...")
device_id = None
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
@@ -216,11 +204,11 @@ def shark_sd_fn(
if stencils:
for i, stencil in enumerate(stencils):
if "xl" not in base_model_id.lower():
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"runwayml/stable-diffusion-v1-5"
][stencil]
else:
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"stabilityai/stable-diffusion-xl-1.0"
][stencil]
@@ -245,54 +233,70 @@ def shark_sd_fn(
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"seeds": seeds,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"preprocessed_hints": preprocessed_hints,
}
global sd_pipe
global sd_pipe_kwargs
if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
sd_pipe = None
sd_pipe_kwargs = submit_pipe_kwargs
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
):
print("Regenerating pipeline...")
global_obj.clear_cache()
gc.collect()
if sd_pipe is None:
history[-1][-1] = "Getting the pipeline ready..."
yield history, ""
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = SharkStableDiffusionPipeline(
sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
)
global_obj.set_sd_obj(sd_pipe)
sd_pipe.prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
for prompt, msg, exec_time in progress.tqdm(
out_imgs=sd_pipe.generate_images(**submit_run_kwargs),
desc="Generating Image...",
):
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
ceil(steps / strength),
strength,
guidance_scale,
seeds[current_batch],
stencils,
resample_type=resample_type,
control_mode=control_mode,
preprocessed_hints=preprocessed_hints,
)
total_time = time.time() - start_time
text_output = []
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
save_output_img(
out_imgs[0],
seeds[current_batch],
extra_info,
sd_json,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
), stencils, images
"Image-to-Image", current_batch + 1, batch_count, batch_size
), stencils
return generated_imgs, text_output, "", stencils, images
return generated_imgs, text_output, "", stencil, image
return generated_imgs, text_output, "", stencil, image
def cancel_sd():

View File

@@ -1,7 +1,4 @@
import os
import sys
import numpy as np
import glob
import json
from random import (
randint,
@@ -11,7 +8,6 @@ from random import (
)
from pathlib import Path
from safetensors.torch import load_file
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
@@ -22,11 +18,6 @@ from shark.iree_utils.vulkan_utils import (
get_iree_vulkan_runtime_flags,
)
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
def get_available_devices():
def get_devices_by_name(driver_name):
@@ -109,6 +100,7 @@ def set_init_device_flags():
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
@@ -150,59 +142,6 @@ def get_all_devices(driver_name):
return device_list_src
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
def get_generated_imgs_path() -> Path:
return Path(
cmd_opts.output_dir
if cmd_opts.output_dir
else get_resource_path("..\web\generated_imgs")
)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def create_checkpoint_folders():
dir = ["vae", "lora"]
if not cmd_opts.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(cmd_opts.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exists."
)
for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model=""):
return get_resource_path(f"..\web\models\{model}")
def get_checkpoints(model="models"):
ckpt_files = []
file_types = checkpoints_filetypes
if model == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model="models"):
return os.path.join(get_checkpoints_path(model), checkpoint_name)
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
@@ -250,6 +189,7 @@ def get_opt_flags(model, precision="fp16"):
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
)
if "rocm" in cmd_opts.device:
from shark.iree_utils.gpu_utils import get_iree_rocm_args
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
if cmd_opts.iree_constant_folding == False:

View File

@@ -5,7 +5,7 @@ import json
import safetensors
from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.api.utils import get_checkpoint_pathfile
from apps.shark_studio.web.utils.file_utils import get_checkpoint_pathfile, get_path_stem
@dataclass

View File

@@ -1,115 +1,12 @@
import os
import sys
import re
import datetime as dt
import json
from csv import DictWriter
from PIL import Image
from pathlib import Path
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = cmd_opts.hf_model_id
if cmd_opts.ckpt_loc:
img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
img_vae = None
if cmd_opts.custom_vae:
img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
img_lora = None
if cmd_opts.use_lora:
img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if cmd_opts.use_hiresfix:
png_size_text = (
f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
)
else:
png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
pngInfo.add_text(
"parameters",
f"{cmd_opts.prompts[0]}"
f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
f"\nSteps: {cmd_opts.steps},"
f"Sampler: {cmd_opts.scheduler}, "
f"CFG scale: {cmd_opts.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": cmd_opts.scheduler,
"PROMPT": cmd_opts.prompts[0],
"NEG_PROMPT": cmd_opts.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": cmd_opts.guidance_scale,
"PRECISION": cmd_opts.precision,
"STEPS": cmd_opts.steps,
"HEIGHT": cmd_opts.height
if not cmd_opts.use_hiresfix
else cmd_opts.hiresfix_height,
"WIDTH": cmd_opts.width
if not cmd_opts.use_hiresfix
else cmd_opts.hiresfix_width,
"MAX_LENGTH": cmd_opts.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if cmd_opts.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
@@ -122,6 +19,96 @@ resamplers = {
resampler_list = resamplers.keys()
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
    """Record generation metadata for one output image.

    Appends a row per image to ``imgs_details.csv`` in today's output
    subdirectory and writes the same metadata as ``<name>.json``.

    NOTE(review): `output_img` itself is currently never written to disk —
    the PNG/JPG saving code was removed in the JSON-config migration.
    TODO: restore image saving driven by the request config.

    :param output_img: generated PIL image (currently unused, see note).
    :param img_seed: seed used for this image; embedded in the filename.
    :param extra_info: dict of generation parameters; assumed to carry a
        non-empty "prompts" list — confirm against callers.
    """
    # Imported lazily to avoid a circular import at module load time.
    from apps.shark_studio.web.utils.file_utils import (
        get_generated_imgs_path,
        get_generated_imgs_todays_subdir,
    )

    if extra_info is None:
        extra_info = {}
    generated_imgs_path = Path(
        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
    )
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompts"][0][:15])
    # The module imports `datetime as dt`, so the class is dt.datetime;
    # a bare dt.now() raises AttributeError.
    out_img_name = f"{dt.datetime.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"

    new_entry = {}
    new_entry.update(extra_info)

    csv_mode = "a" if os.path.isfile(csv_path) else "w"
    with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if csv_mode == "w":
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)

    # "OUTPUT" is only present when the caller supplied it; an unconditional
    # `del` raised KeyError otherwise.
    new_entry.pop("OUTPUT", None)
    json_path = Path(generated_imgs_path, f"{out_img_name}.json")
    with open(json_path, "w") as f:
        json.dump(new_entry, f, indent=4)
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.

View File

@@ -675,6 +675,13 @@ p.add_argument(
"images under --output_dir in the UI.",
)
p.add_argument(
"--configs_path",
default=None,
type=str,
help="Path to .json config directory."
)
p.add_argument(
"--output_gallery_followlinks",
default=False,

View File

@@ -135,7 +135,7 @@ def webui():
clear_tmp_mlir,
clear_tmp_imgs,
)
from apps.shark_studio.api.utils import (
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)

View File

@@ -6,7 +6,7 @@ import sys
from PIL import Image
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import (
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)

View File

@@ -1,9 +1,7 @@
import os
import time
import gradio as gr
import PIL
import json
import sys
import gradio as gr
import numpy as np
from math import ceil
from inspect import signature
@@ -12,15 +10,17 @@ from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
Brush,
Eraser,
EditorValue,
)
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_checkpoints_path,
get_checkpoints,
get_configs_path,
)
from apps.shark_studio.api.sd import (
sd_model_map,
@@ -28,8 +28,6 @@ from apps.shark_studio.api.sd import (
cancel_sd,
)
from apps.shark_studio.api.controlnet import (
preprocessor_model_map,
PreprocessorModel,
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
@@ -44,7 +42,6 @@ from apps.shark_studio.web.ui.utils import (
nodlogo_loc,
)
from apps.shark_studio.web.utils.state import (
get_generation_text_info,
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
@@ -57,32 +54,58 @@ def view_json_file(file_obj):
return content
max_controlnets = 3
max_loras = 5
def submit_to_cnet_config(
    stencil: str,
    preprocessed_hint: str,
    cnet_strength: int,
    curr_config: dict,
):
    """Append one controlnet entry (model/hint/strength) to the cnet config.

    :param stencil: controlnet model name selected in the UI.
    :param preprocessed_hint: path of the preprocessed hint image.
    :param cnet_strength: controlnet strength slider value.
    :param curr_config: current cnet config dict (or None on first submit).
    :return: updated config dict, or gr.update() (no-op) when either the
        stencil or the hint is missing.
    """
    if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
        # Nothing usable to submit; leave the JSON component unchanged.
        return gr.update()
    if curr_config is not None:
        if "controlnets" in curr_config:
            curr_config["controlnets"]["model"].append(stencil)
            curr_config["controlnets"]["hint"].append(preprocessed_hint)
            curr_config["controlnets"]["strength"].append(cnet_strength)
        else:
            # Previously this case returned a fresh dict, silently dropping
            # every other key in curr_config; preserve them instead.
            curr_config["controlnets"] = {
                "model": [stencil],
                "hint": [preprocessed_hint],
                "strength": [cnet_strength],
            }
        return curr_config
    return {
        "controlnets": {
            "model": [stencil],
            "hint": [preprocessed_hint],
            "strength": [cnet_strength],
        }
    }
def show_loras(k):
    """Make the first *k* LoRA dropdowns visible; hide and reset the rest."""
    count = int(k)
    shown = [gr.Dropdown(visible=True)] * count
    hidden = [gr.Dropdown(visible=False, value="None")] * (max_loras - count)
    return gr.State(shown + hidden)
def update_embeddings_json(embedding, curr_config: dict):
    """Append *embedding* to the "embeddings" list of the config.

    :param embedding: embedding/LoRA identifier chosen in the UI.
    :param curr_config: current embeddings config dict, or None.
    :return: the updated config dict.
    """
    if curr_config is not None:
        # setdefault keeps any other keys already in curr_config; the old
        # fall-through returned a fresh dict and discarded them.
        curr_config.setdefault("embeddings", []).append(embedding)
        return curr_config
    return {"embeddings": [embedding]}
def show_controlnets(k):
    """Show the first *k* controlnet rows and reset the per-row state lists."""
    count = int(k)
    shown = [gr.Row(visible=True, render=True)] * count
    hidden = [gr.Row(visible=False)] * (max_controlnets - count)
    return [
        gr.State([shown + hidden]),
        gr.State([None] * count),
        gr.State([None] * count),
        gr.State([None] * count),
    ]
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
    """Merge *input_cfg* into *main_cfg*; top-level keys overwrite.

    main_cfg is "" only before the JSON component is first populated, in
    which case input_cfg becomes the whole config.
    """
    if main_cfg is None or main_cfg == "":
        return input_cfg
    main_cfg.update(input_cfg)
    return main_cfg
def save_sd_cfg(config: dict, save_name: str):
    """Write *config* as JSON to *save_name*.

    An existing path is overwritten in place; otherwise the file is placed
    under --configs_path (when given) or the default configs directory.

    :param config: the SD request config dict to persist.
    :param save_name: filename or full path to save to.
    :return: "..." — placeholder text echoed back into the name textbox.
    """
    if os.path.exists(save_name):
        filepath = save_name
    elif cmd_opts.configs_path:
        filepath = os.path.join(cmd_opts.configs_path, save_name)
    else:
        filepath = os.path.join(get_configs_path(), save_name)
    # endswith, not substring: a name like "my.json.bak" must still get the
    # .json suffix appended.
    if not filepath.endswith(".json"):
        filepath += ".json"
    with open(filepath, mode="w") as f:
        json.dump(config, f)
    return "..."
def create_canvas(width, height):
data = Image.fromarray(
@@ -101,30 +124,34 @@ def create_canvas(width, height):
def import_original(original_img, width, height):
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [resized_img],
"composite": None,
}
return gr.ImageEditor(
value=EditorValue(img_dict),
crop_size=(width, height),
)
if original_img is None:
resized_img = create_canvas(width, height)
return gr.ImageEditor(
value=resized_img,
crop_size=(width, height),
)
else:
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [resized_img],
"composite": None,
}
return gr.ImageEditor(
value=EditorValue(img_dict),
crop_size=(width, height),
)
def update_cn_input(
model,
width,
height,
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
):
print("update_cn_input")
if model == None:
stencils[index] = None
images[index] = None
preprocessed_hints[index] = None
stencil = None
preprocessed_hint = None
return [
gr.update(),
gr.update(),
@@ -132,9 +159,8 @@ def update_cn_input(
gr.update(),
gr.update(),
gr.update(),
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
]
elif model == "scribble":
return [
@@ -160,9 +186,8 @@ def update_cn_input(
gr.Slider(visible=True, label="Canvas Height"),
gr.Button(visible=True),
gr.Button(visible=False),
stencils,
images,
preprocessed_hints,
stencil,
preprocessed_hint,
]
else:
return [
@@ -181,11 +206,10 @@ def update_cn_input(
),
gr.Slider(visible=True, label="Canvas Width"),
gr.Slider(visible=True, label="Canvas Height"),
gr.Button(visible=True),
gr.Button(visible=False),
stencils,
images,
preprocessed_hints,
gr.Button(visible=True),
stencil,
preprocessed_hint,
]
@@ -230,7 +254,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=sd_model_map.keys(),
) # base_model_id
sd_custom_weights = gr.Dropdown(
label="Weights (Optional)",
label="Custom Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
@@ -253,17 +277,21 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config", size="sm"
)
load_sd_config = gr.FileExplorer(
label="Load Config",
root=os.path.basename("./configs"),
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -288,40 +316,40 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
interactive=True,
)
with gr.Accordion(
label="Embeddings options", open=False, render=True
label="Embeddings options", open=True, render=True
):
sd_lora_info = (
str(get_checkpoints_path("loras"))
).replace("\\", "\n\\")
num_loras = gr.Slider(
1, max_loras, value=1, step=1, label="LoRA Count"
)
loras = gr.State([])
for i in range(max_loras):
with gr.Row():
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_checkpoints("lora"),
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
).replace("\\", "\n\\")
with gr.Column(scale=2):
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_checkpoints("lora"),
)
loras.value.append(lora_opt)
num_loras.change(show_loras, [num_loras], [loras])
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Column(scale=1):
embeddings_config = gr.JSON()
submit_embeddings = gr.Button("Submit to Main Config", size="sm")
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
)
gr.on(
triggers=[lora_opt.change],
fn=update_embeddings_json,
inputs=[lora_opt, embeddings_config],
outputs=[embeddings_config],
)
with gr.Accordion(label="Advanced Options", open=True):
with gr.Row():
scheduler = gr.Dropdown(
@@ -331,7 +359,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
height = gr.Slider(
384,
768,
@@ -397,20 +424,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
seed = gr.Textbox(
value=cmd_opts.seed,
@@ -425,49 +438,52 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=False,
)
with gr.Accordion(
label="Controlnet Options", open=False, render=False
label="Controlnet Options", open=False, render=True
):
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
num_cnets = gr.Slider(
0,
max_controlnets,
value=0,
step=1,
label="Controlnet Count",
)
cnet_rows = []
stencils = gr.State([])
images = gr.State([])
preprocessed_hints = gr.State([])
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Column():
cnet_config = gr.JSON()
submit_cnet = gr.Button("Submit to Main Config", size="sm")
with gr.Column():
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
stencil = gr.State("")
preprocessed_hint = gr.State("")
with gr.Row():
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
for i in range(max_controlnets):
with gr.Row(visible=False) as cnet_row:
with gr.Column():
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
elem_id="lora_weights",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
with gr.Row(visible=True) as cnet_row:
with gr.Column():
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
elem_id="lora_weights",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
cnet_strength = gr.Slider(
label="Controlnet Strength",
minimum=0,
maximum=100,
value=50,
step=1,
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
@@ -484,22 +500,24 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
step=1,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
use_input_img = gr.Button(
value="Use Original Image",
visible=False,
)
cnet_input = gr.ImageEditor(
visible=True,
image_mode="RGB",
interactive=True,
show_label=True,
label="Input Image",
type="pil",
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
use_input_img = gr.Button(
value="Use Original Image",
visible=False,
size="sm",
)
cnet_input = gr.ImageEditor(
visible=True,
image_mode="RGB",
interactive=True,
show_label=True,
label="Input Image",
type="pil",
)
with gr.Column():
cnet_output = gr.Image(
value=None,
visible=True,
@@ -507,63 +525,66 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
interactive=True,
show_label=True,
)
use_input_img.click(
import_original,
[sd_init_image, canvas_width, canvas_height],
[cnet_input],
use_result = gr.Button(
"Submit",
size="sm",
)
cnet_model.change(
fn=update_cn_input,
inputs=[
cnet_model,
canvas_width,
canvas_height,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_input,
cnet_output,
canvas_width,
canvas_height,
make_canvas,
use_input_img,
stencils,
images,
preprocessed_hints,
],
)
make_canvas.click(
create_canvas,
[canvas_width, canvas_height],
[
cnet_input,
],
)
gr.on(
triggers=[cnet_gen.click],
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_output,
stencils,
images,
preprocessed_hints,
],
)
cnet_rows.value.append(cnet_row)
num_cnets.change(
show_controlnets,
[num_cnets],
[cnet_rows, stencils, images, preprocessed_hints],
use_input_img.click(
import_original,
[sd_init_image, canvas_width, canvas_height],
[cnet_input],
)
cnet_model.change(
fn=update_cn_input,
inputs=[
cnet_model,
stencil,
preprocessed_hint,
],
outputs=[
cnet_input,
cnet_output,
canvas_width,
canvas_height,
make_canvas,
use_input_img,
stencil,
preprocessed_hint,
],
)
make_canvas.click(
create_canvas,
[canvas_width, canvas_height],
[
cnet_input,
],
)
gr.on(
triggers=[cnet_gen.click],
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
stencil,
preprocessed_hint,
],
outputs=[
cnet_output,
stencil,
preprocessed_hint,
],
)
use_result.click(
fn=submit_to_cnet_config,
inputs=[
stencil,
preprocessed_hint,
cnet_strength,
cnet_config,
],
outputs=[
cnet_config,
]
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -593,6 +614,30 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Group():
sd_json = gr.JSON()
with gr.Row():
clear_sd_config = gr.ClearButton(
value="Clear Config", size="sm"
)
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
)
load_sd_config = gr.FileExplorer(
label="Load Config",
root=cmd_opts.configs_path if cmd_opts.configs_path else get_configs_path(),
height=75,
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
kwargs = dict(
fn=shark_sd_fn,
@@ -614,21 +659,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
sd_custom_vae,
precision,
device,
loras,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
preprocessed_hints,
sd_json,
],
outputs=[
sd_gallery,
std_output,
sd_status,
stencils,
images,
],
show_progress="minimal",
)
@@ -648,3 +688,15 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
gr.on(
triggers=[submit_cnet.click],
fn=submit_to_main_config,
inputs=[cnet_config, sd_json],
outputs=[sd_json],
)
gr.on(
triggers=[submit_embeddings.click],
fn=submit_to_main_config,
inputs=[embeddings_config, sd_json],
outputs=[sd_json],
)

View File

@@ -0,0 +1,77 @@
import os
import sys
import glob
import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# Glob patterns for checkpoint files recognized by get_checkpoints().
checkpoints_filetypes = (
    "*.ckpt",
    "*.safetensors",
)
def get_path_stem(path):
    """Return the final path component of *path* without its suffix."""
    return Path(path).stem
def get_resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller.

    Resolves *relative_path* against the PyInstaller bundle dir (_MEIPASS)
    when frozen, otherwise against this module's directory.
    """
    if hasattr(sys, "_MEIPASS"):
        base = sys._MEIPASS
    else:
        base = os.path.dirname(os.path.abspath(__file__))
    return Path(base, relative_path).resolve(strict=False)
def get_configs_path() -> Path:
    """Return the configs directory next to the web root, creating it if absent.

    The previous version resolved the path twice with inconsistent
    separators and used os.mkdir, which races with concurrent callers.
    """
    configs = get_resource_path(os.path.join("..", "configs"))
    # exist_ok avoids a race between the existence check and creation.
    os.makedirs(configs, exist_ok=True)
    return Path(configs)
def get_generated_imgs_path() -> Path:
    """Directory for generated images: --output_dir when set, else ../generated_imgs."""
    if cmd_opts.output_dir:
        target = cmd_opts.output_dir
    else:
        target = get_resource_path("../generated_imgs")
    return Path(target)
def get_generated_imgs_todays_subdir() -> str:
    """Return today's date as YYYYMMDD, used as the per-day output subfolder.

    The module imports ``datetime as dt``, so the class is ``dt.datetime``;
    the previous bare ``dt.now()`` raised AttributeError.
    """
    return dt.datetime.now().strftime("%Y%m%d")
def create_checkpoint_folders():
    """Create the checkpoint subfolders (vae/lora, plus models unless --ckpt_dir is set).

    Exits the process when --ckpt_dir points at a non-existent directory.
    """
    subdirs = ["vae", "lora"]
    if not cmd_opts.ckpt_dir:
        subdirs.insert(0, "models")
    elif not os.path.isdir(cmd_opts.ckpt_dir):
        sys.exit(
            f"Invalid --ckpt_dir argument, "
            f"{cmd_opts.ckpt_dir} folder does not exists."
        )
    for name in subdirs:
        Path(get_checkpoints_path(name)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model=""):
    """Resolve the checkpoints folder for *model* under ../models."""
    relative = "../models/{}".format(model)
    return get_resource_path(relative)
def get_checkpoints(model="models"):
    """List checkpoint filenames for *model*, sorted case-insensitively."""
    patterns = checkpoints_filetypes
    if model == "lora":
        # LoRA weights ship in extra formats.
        patterns = patterns + ("*.pt", "*.bin")
    root = get_checkpoints_path(model)
    found = []
    for pattern in patterns:
        for match in glob.glob(os.path.join(root, pattern)):
            found.append(os.path.basename(match))
    return sorted(found, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model="models"):
    """Full path to *checkpoint_name* inside the *model* checkpoints folder."""
    folder = get_checkpoints_path(model)
    return os.path.join(folder, checkpoint_name)

View File

@@ -1,6 +1,6 @@
import re
from pathlib import Path
from apps.shark_studio.api.utils import (
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import (