mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
Migrating to JSON requests in SD UI
This commit is contained in:
@@ -1,5 +1,15 @@
|
||||
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
|
||||
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
)
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
from gradio.components.image_editor import (
|
||||
EditorValue,
|
||||
)
|
||||
|
||||
class control_adapter:
|
||||
def __init__(
|
||||
@@ -57,20 +67,26 @@ class PreprocessorModel:
|
||||
def __init__(
|
||||
self,
|
||||
hf_model_id,
|
||||
device,
|
||||
device = "cpu",
|
||||
):
|
||||
self.model = None
|
||||
self.model = hf_model_id
|
||||
self.device = device
|
||||
|
||||
def compile(self, device):
|
||||
def compile(self):
|
||||
print("compile not implemented for preprocessor.")
|
||||
return
|
||||
|
||||
def run(self, inputs):
|
||||
print("run not implemented for preprocessor.")
|
||||
return
|
||||
return inputs
|
||||
|
||||
|
||||
def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
|
||||
def cnet_preview(model, input_image, stencil, preprocessed_hint):
|
||||
curr_datetime = datetime.now().strftime('%Y-%m-%d.%H-%M-%S')
|
||||
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
|
||||
if not os.path.exists(control_imgs_path):
|
||||
os.mkdir(control_imgs_path)
|
||||
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
|
||||
if isinstance(input_image, PIL.Image.Image):
|
||||
img_dict = {
|
||||
"background": None,
|
||||
@@ -78,57 +94,52 @@ def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
|
||||
"composite": input_image,
|
||||
}
|
||||
input_image = EditorValue(img_dict)
|
||||
images[index] = input_image
|
||||
preprocessed_hint = img_dest
|
||||
if model:
|
||||
stencils[index] = model
|
||||
stencil = model
|
||||
match model:
|
||||
case "canny":
|
||||
canny = CannyDetector()
|
||||
canny = PreprocessorModel("canny")
|
||||
result = canny(
|
||||
np.array(input_image["composite"]),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(result)
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
)
|
||||
case "openpose":
|
||||
openpose = OpenposeDetector()
|
||||
openpose = PreprocessorModel("openpose")
|
||||
result = openpose(np.array(input_image["composite"]))
|
||||
preprocessed_hints[index] = Image.fromarray(result[0])
|
||||
Image.fromarray(result[0]).save(fp=img_dest)
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
)
|
||||
case "zoedepth":
|
||||
zoedepth = ZoeDetector()
|
||||
zoedepth = PreprocessorModel("ZoeDepth")
|
||||
result = zoedepth(np.array(input_image["composite"]))
|
||||
preprocessed_hints[index] = Image.fromarray(result)
|
||||
Image.fromarray(result).save(fp=img_dest)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
)
|
||||
case "scribble":
|
||||
preprocessed_hints[index] = input_image["composite"]
|
||||
input_image["composite"].save(fp=img_dest)
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
)
|
||||
case _:
|
||||
preprocessed_hints[index] = None
|
||||
preprocessed_hint = None
|
||||
return (
|
||||
None,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from shark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
load_vmfb_using_mmap,
|
||||
)
|
||||
from apps.shark_studio.api.utils import get_resource_path
|
||||
from apps.shark_studio.web.utils.file_utils import get_resource_path
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
import gc
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from turbine_models.custom_models.sd_inference import clip, unet, vae
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
from apps.shark_studio.api.utils import get_resource_path
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
from apps.shark_studio.web.utils.state import status_label
|
||||
from apps.shark_studio.modules.pipeline import SharkPipelineBase
|
||||
import iree.runtime as ireert
|
||||
from apps.shark_studio.modules.img_processing import resize_stencil, save_output_img
|
||||
from math import ceil
|
||||
import gc
|
||||
import torch
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import time
|
||||
|
||||
sd_model_map = {
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
@@ -154,7 +155,7 @@ def shark_sd_fn(
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
seeds: list,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -168,24 +169,13 @@ def shark_sd_fn(
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
preprocessed_hints: list,
|
||||
sd_json: dict,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
continue
|
||||
elif stencil is None and any(
|
||||
img is not None for img in [images[i], preprocessed_hints[i]]
|
||||
):
|
||||
images[i] = None
|
||||
preprocessed_hints[i] = None
|
||||
elif images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
stencils=[]
|
||||
preprocessed_hints=[]
|
||||
cnet_strengths=[]
|
||||
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
@@ -203,8 +193,6 @@ def shark_sd_fn(
|
||||
is_img2img = True
|
||||
print("Performing Stable Diffusion Pipeline setup...")
|
||||
|
||||
device_id = None
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
@@ -216,11 +204,11 @@ def shark_sd_fn(
|
||||
if stencils:
|
||||
for i, stencil in enumerate(stencils):
|
||||
if "xl" not in base_model_id.lower():
|
||||
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
|
||||
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
|
||||
"runwayml/stable-diffusion-v1-5"
|
||||
][stencil]
|
||||
else:
|
||||
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
|
||||
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
|
||||
"stabilityai/stable-diffusion-xl-1.0"
|
||||
][stencil]
|
||||
|
||||
@@ -245,54 +233,70 @@ def shark_sd_fn(
|
||||
"steps": steps,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"seeds": seeds,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"preprocessed_hints": preprocessed_hints,
|
||||
}
|
||||
|
||||
global sd_pipe
|
||||
global sd_pipe_kwargs
|
||||
|
||||
if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
|
||||
sd_pipe = None
|
||||
sd_pipe_kwargs = submit_pipe_kwargs
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
):
|
||||
print("Regenerating pipeline...")
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
|
||||
if sd_pipe is None:
|
||||
history[-1][-1] = "Getting the pipeline ready..."
|
||||
yield history, ""
|
||||
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
|
||||
|
||||
# Initializes the pipeline and retrieves IR based on all
|
||||
# parameters that are static in the turbine output format,
|
||||
# which is currently MLIR in the torch dialect.
|
||||
|
||||
sd_pipe = SharkStableDiffusionPipeline(
|
||||
sd_pipe = StableDiffusion(
|
||||
**submit_pipe_kwargs,
|
||||
)
|
||||
global_obj.set_sd_obj(sd_pipe)
|
||||
|
||||
sd_pipe.prepare_pipe(**submit_prep_kwargs)
|
||||
generated_imgs = []
|
||||
|
||||
for prompt, msg, exec_time in progress.tqdm(
|
||||
out_imgs=sd_pipe.generate_images(**submit_run_kwargs),
|
||||
desc="Generating Image...",
|
||||
):
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
stencils,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
preprocessed_hints=preprocessed_hints,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = []
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
# break
|
||||
# else:
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
seeds[current_batch],
|
||||
extra_info,
|
||||
sd_json,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
), stencils, images
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
), stencils
|
||||
|
||||
return generated_imgs, text_output, "", stencils, images
|
||||
return generated_imgs, text_output, "", stencil, image
|
||||
|
||||
return generated_imgs, text_output, "", stencil, image
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import glob
|
||||
import json
|
||||
from random import (
|
||||
randint,
|
||||
@@ -11,7 +8,6 @@ from random import (
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
@@ -22,11 +18,6 @@ from shark.iree_utils.vulkan_utils import (
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
checkpoints_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
@@ -109,6 +100,7 @@ def set_init_device_flags():
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
|
||||
@@ -150,59 +142,6 @@ def get_all_devices(driver_name):
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
cmd_opts.output_dir
|
||||
if cmd_opts.output_dir
|
||||
else get_resource_path("..\web\generated_imgs")
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def create_checkpoint_folders():
|
||||
dir = ["vae", "lora"]
|
||||
if not cmd_opts.ckpt_dir:
|
||||
dir.insert(0, "models")
|
||||
else:
|
||||
if not os.path.isdir(cmd_opts.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, "
|
||||
f"{args.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_checkpoints_path(model=""):
|
||||
return get_resource_path(f"..\web\models\{model}")
|
||||
|
||||
|
||||
def get_checkpoints(model="models"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_checkpoint_pathfile(checkpoint_name, model="models"):
|
||||
return os.path.join(get_checkpoints_path(model), checkpoint_name)
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
@@ -250,6 +189,7 @@ def get_opt_flags(model, precision="fp16"):
|
||||
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rocm" in cmd_opts.device:
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
if cmd_opts.iree_constant_folding == False:
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import safetensors
|
||||
from dataclasses import dataclass
|
||||
from safetensors.torch import load_file
|
||||
from apps.shark_studio.api.utils import get_checkpoint_pathfile
|
||||
from apps.shark_studio.web.utils.file_utils import get_checkpoint_pathfile, get_path_stem
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,115 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import datetime as dt
|
||||
import json
|
||||
from csv import DictWriter
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
generated_imgs_path = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15])
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
img_model = cmd_opts.hf_model_id
|
||||
if cmd_opts.ckpt_loc:
|
||||
img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
|
||||
|
||||
img_vae = None
|
||||
if cmd_opts.custom_vae:
|
||||
img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
|
||||
|
||||
img_lora = None
|
||||
if cmd_opts.use_lora:
|
||||
img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
|
||||
if cmd_opts.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if cmd_opts.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if cmd_opts.use_hiresfix:
|
||||
png_size_text = (
|
||||
f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
)
|
||||
else:
|
||||
png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{cmd_opts.prompts[0]}"
|
||||
f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
|
||||
f"\nSteps: {cmd_opts.steps},"
|
||||
f"Sampler: {cmd_opts.scheduler}, "
|
||||
f"CFG scale: {cmd_opts.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": cmd_opts.scheduler,
|
||||
"PROMPT": cmd_opts.prompts[0],
|
||||
"NEG_PROMPT": cmd_opts.negative_prompts[0],
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": cmd_opts.guidance_scale,
|
||||
"PRECISION": cmd_opts.precision,
|
||||
"STEPS": cmd_opts.steps,
|
||||
"HEIGHT": cmd_opts.height
|
||||
if not cmd_opts.use_hiresfix
|
||||
else cmd_opts.hiresfix_height,
|
||||
"WIDTH": cmd_opts.width
|
||||
if not cmd_opts.use_hiresfix
|
||||
else cmd_opts.hiresfix_width,
|
||||
"MAX_LENGTH": cmd_opts.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
"LORA": img_lora,
|
||||
}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if cmd_opts.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
@@ -122,6 +19,96 @@ resamplers = {
|
||||
resampler_list = resamplers.keys()
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
from apps.shark_studio.web.utils.file_utils import get_generated_imgs_path, get_generated_imgs_todays_subdir
|
||||
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
generated_imgs_path = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompts"][0][:15])
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
#img_model = cmd_opts.hf_model_id
|
||||
#if cmd_opts.ckpt_loc:
|
||||
# img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
|
||||
|
||||
#img_vae = None
|
||||
#if cmd_opts.custom_vae:
|
||||
# img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
|
||||
|
||||
#img_lora = None
|
||||
#if cmd_opts.use_lora:
|
||||
# img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
|
||||
#if cmd_opts.output_img_format == "jpg":
|
||||
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
# output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
#else:
|
||||
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
# pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
# if cmd_opts.write_metadata_to_png:
|
||||
# # Using a conditional expression caused problems, so setting a new
|
||||
# # variable for now.
|
||||
# if cmd_opts.use_hiresfix:
|
||||
# png_size_text = (
|
||||
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
# )
|
||||
# else:
|
||||
# png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
|
||||
|
||||
# pngInfo.add_text(
|
||||
# "parameters",
|
||||
# f"{cmd_opts.prompts[0]}"
|
||||
# f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
|
||||
# f"\nSteps: {cmd_opts.steps},"
|
||||
# f"Sampler: {cmd_opts.scheduler}, "
|
||||
# f"CFG scale: {cmd_opts.guidance_scale}, "
|
||||
# f"Seed: {img_seed},"
|
||||
# f"Size: {png_size_text}, "
|
||||
# f"Model: {img_model}, "
|
||||
# f"VAE: {img_vae}, "
|
||||
# f"LoRA: {img_lora}",
|
||||
# )
|
||||
|
||||
# output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
# if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
# print(
|
||||
# f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
# f"supported yet. Image saved as png instead."
|
||||
# f"Supported formats: png / jpg"
|
||||
# )
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
# For stencil, the input image can be of any size, but we need to ensure that
|
||||
# it conforms with our model constraints :-
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
|
||||
@@ -675,6 +675,13 @@ p.add_argument(
|
||||
"images under --output_dir in the UI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--configs_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to .json config directory."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
|
||||
@@ -135,7 +135,7 @@ def webui():
|
||||
clear_tmp_mlir,
|
||||
clear_tmp_imgs,
|
||||
)
|
||||
from apps.shark_studio.api.utils import (
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
create_checkpoint_folders,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import sys
|
||||
from PIL import Image
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.api.utils import (
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import json
|
||||
import sys
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
from math import ceil
|
||||
from inspect import signature
|
||||
@@ -12,15 +10,17 @@ from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorValue,
|
||||
)
|
||||
|
||||
from apps.shark_studio.api.utils import (
|
||||
get_available_devices,
|
||||
)
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_checkpoints_path,
|
||||
get_checkpoints,
|
||||
get_configs_path,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
sd_model_map,
|
||||
@@ -28,8 +28,6 @@ from apps.shark_studio.api.sd import (
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
preprocessor_model_map,
|
||||
PreprocessorModel,
|
||||
cnet_preview,
|
||||
)
|
||||
from apps.shark_studio.modules.schedulers import (
|
||||
@@ -44,7 +42,6 @@ from apps.shark_studio.web.ui.utils import (
|
||||
nodlogo_loc,
|
||||
)
|
||||
from apps.shark_studio.web.utils.state import (
|
||||
get_generation_text_info,
|
||||
status_label,
|
||||
)
|
||||
from apps.shark_studio.web.ui.common_events import lora_changed
|
||||
@@ -57,32 +54,58 @@ def view_json_file(file_obj):
|
||||
return content
|
||||
|
||||
|
||||
max_controlnets = 3
|
||||
max_loras = 5
|
||||
def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: int, curr_config: dict):
|
||||
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
|
||||
return gr.update()
|
||||
if curr_config is not None:
|
||||
if "controlnets" in curr_config:
|
||||
curr_config["controlnets"]["model"].append(stencil)
|
||||
curr_config["controlnets"]["hint"].append(preprocessed_hint)
|
||||
curr_config["controlnets"]["strength"].append(cnet_strength)
|
||||
return curr_config
|
||||
|
||||
cnet_map = {}
|
||||
cnet_map["controlnets"] = {
|
||||
"model": [stencil],
|
||||
"hint": [preprocessed_hint],
|
||||
"strength": [cnet_strength],
|
||||
}
|
||||
return cnet_map
|
||||
|
||||
|
||||
def show_loras(k):
|
||||
k = int(k)
|
||||
return gr.State(
|
||||
[gr.Dropdown(visible=True)] * k
|
||||
+ [gr.Dropdown(visible=False, value="None")] * (max_loras - k)
|
||||
)
|
||||
def update_embeddings_json(embedding, curr_config: dict):
|
||||
if curr_config is not None:
|
||||
if "embeddings" in curr_config:
|
||||
curr_config["embeddings"].append(embedding)
|
||||
return curr_config
|
||||
|
||||
config = {"embeddings": [embedding]}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def show_controlnets(k):
|
||||
k = int(k)
|
||||
return [
|
||||
gr.State(
|
||||
[
|
||||
[gr.Row(visible=True, render=True)] * k
|
||||
+ [gr.Row(visible=False)] * (max_controlnets - k)
|
||||
]
|
||||
),
|
||||
gr.State([None] * k),
|
||||
gr.State([None] * k),
|
||||
gr.State([None] * k),
|
||||
]
|
||||
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
|
||||
if main_cfg in [None, ""]:
|
||||
# only time main_cfg should be a string is empty case.
|
||||
return input_cfg
|
||||
|
||||
for base_key in input_cfg:
|
||||
main_cfg[base_key] = input_cfg[base_key]
|
||||
return main_cfg
|
||||
|
||||
|
||||
def save_sd_cfg(config: dict, save_name: str):
|
||||
if os.path.exists(save_name):
|
||||
filepath=save_name
|
||||
elif cmd_opts.configs_path:
|
||||
filepath=os.path.join(cmd_opts.configs_path, save_name)
|
||||
else:
|
||||
filepath=os.path.join(get_configs_path(), save_name)
|
||||
if ".json" not in filepath:
|
||||
filepath += ".json"
|
||||
with open(filepath, mode="w") as f:
|
||||
f.write(json.dumps(config))
|
||||
return("...")
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = Image.fromarray(
|
||||
@@ -101,30 +124,34 @@ def create_canvas(width, height):
|
||||
|
||||
|
||||
def import_original(original_img, width, height):
|
||||
resized_img, _, _ = resize_stencil(original_img, width, height)
|
||||
img_dict = {
|
||||
"background": resized_img,
|
||||
"layers": [resized_img],
|
||||
"composite": None,
|
||||
}
|
||||
return gr.ImageEditor(
|
||||
value=EditorValue(img_dict),
|
||||
crop_size=(width, height),
|
||||
)
|
||||
if original_img is None:
|
||||
resized_img = create_canvas(width, height)
|
||||
return gr.ImageEditor(
|
||||
value=resized_img,
|
||||
crop_size=(width, height),
|
||||
)
|
||||
else:
|
||||
resized_img, _, _ = resize_stencil(original_img, width, height)
|
||||
img_dict = {
|
||||
"background": resized_img,
|
||||
"layers": [resized_img],
|
||||
"composite": None,
|
||||
}
|
||||
return gr.ImageEditor(
|
||||
value=EditorValue(img_dict),
|
||||
crop_size=(width, height),
|
||||
)
|
||||
|
||||
|
||||
def update_cn_input(
|
||||
model,
|
||||
width,
|
||||
height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
):
|
||||
print("update_cn_input")
|
||||
if model == None:
|
||||
stencils[index] = None
|
||||
images[index] = None
|
||||
preprocessed_hints[index] = None
|
||||
stencil = None
|
||||
preprocessed_hint = None
|
||||
return [
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
@@ -132,9 +159,8 @@ def update_cn_input(
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
]
|
||||
elif model == "scribble":
|
||||
return [
|
||||
@@ -160,9 +186,8 @@ def update_cn_input(
|
||||
gr.Slider(visible=True, label="Canvas Height"),
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
@@ -181,11 +206,10 @@ def update_cn_input(
|
||||
),
|
||||
gr.Slider(visible=True, label="Canvas Width"),
|
||||
gr.Slider(visible=True, label="Canvas Height"),
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
gr.Button(visible=True),
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
]
|
||||
|
||||
|
||||
@@ -230,7 +254,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=sd_model_map.keys(),
|
||||
) # base_model_id
|
||||
sd_custom_weights = gr.Dropdown(
|
||||
label="Weights (Optional)",
|
||||
label="Custom Weights",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
@@ -253,17 +277,21 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config", size="sm"
|
||||
)
|
||||
load_sd_config = gr.FileExplorer(
|
||||
label="Load Config",
|
||||
root=os.path.basename("./configs"),
|
||||
)
|
||||
with gr.Row():
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=cmd_opts.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
@@ -288,40 +316,40 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="Embeddings options", open=False, render=True
|
||||
label="Embeddings options", open=True, render=True
|
||||
):
|
||||
sd_lora_info = (
|
||||
str(get_checkpoints_path("loras"))
|
||||
).replace("\\", "\n\\")
|
||||
num_loras = gr.Slider(
|
||||
1, max_loras, value=1, step=1, label="LoRA Count"
|
||||
)
|
||||
loras = gr.State([])
|
||||
for i in range(max_loras):
|
||||
with gr.Row():
|
||||
lora_opt = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=sd_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_checkpoints("lora"),
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
gr.on(
|
||||
triggers=[lora_opt.change],
|
||||
fn=lora_changed,
|
||||
inputs=[lora_opt],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
).replace("\\", "\n\\")
|
||||
with gr.Column(scale=2):
|
||||
lora_opt = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=sd_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_checkpoints("lora"),
|
||||
)
|
||||
loras.value.append(lora_opt)
|
||||
|
||||
num_loras.change(show_loras, [num_loras], [loras])
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
embeddings_config = gr.JSON()
|
||||
submit_embeddings = gr.Button("Submit to Main Config", size="sm")
|
||||
gr.on(
|
||||
triggers=[lora_opt.change],
|
||||
fn=lora_changed,
|
||||
inputs=[lora_opt],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
gr.on(
|
||||
triggers=[lora_opt.change],
|
||||
fn=update_embeddings_json,
|
||||
inputs=[lora_opt, embeddings_config],
|
||||
outputs=[embeddings_config],
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=True):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -331,7 +359,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
@@ -397,20 +424,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=cmd_opts.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=cmd_opts.seed,
|
||||
@@ -425,49 +438,52 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="Controlnet Options", open=False, render=False
|
||||
label="Controlnet Options", open=False, render=True
|
||||
):
|
||||
sd_cnet_info = (
|
||||
str(get_checkpoints_path("controlnet"))
|
||||
).replace("\\", "\n\\")
|
||||
num_cnets = gr.Slider(
|
||||
0,
|
||||
max_controlnets,
|
||||
value=0,
|
||||
step=1,
|
||||
label="Controlnet Count",
|
||||
)
|
||||
cnet_rows = []
|
||||
stencils = gr.State([])
|
||||
images = gr.State([])
|
||||
preprocessed_hints = gr.State([])
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
label="Control Mode",
|
||||
)
|
||||
with gr.Column():
|
||||
cnet_config = gr.JSON()
|
||||
submit_cnet = gr.Button("Submit to Main Config", size="sm")
|
||||
with gr.Column():
|
||||
sd_cnet_info = (
|
||||
str(get_checkpoints_path("controlnet"))
|
||||
).replace("\\", "\n\\")
|
||||
stencil = gr.State("")
|
||||
preprocessed_hint = gr.State("")
|
||||
with gr.Row():
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
label="Control Mode",
|
||||
)
|
||||
|
||||
for i in range(max_controlnets):
|
||||
with gr.Row(visible=False) as cnet_row:
|
||||
with gr.Column():
|
||||
cnet_gen = gr.Button(
|
||||
value="Preprocess controlnet input",
|
||||
)
|
||||
cnet_model = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Controlnet Model",
|
||||
info=sd_cnet_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
+ get_checkpoints("controlnet"),
|
||||
)
|
||||
with gr.Row(visible=True) as cnet_row:
|
||||
with gr.Column():
|
||||
cnet_gen = gr.Button(
|
||||
value="Preprocess controlnet input",
|
||||
)
|
||||
cnet_model = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Controlnet Model",
|
||||
info=sd_cnet_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
+ get_checkpoints("controlnet"),
|
||||
)
|
||||
cnet_strength = gr.Slider(
|
||||
label="Controlnet Strength",
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
value=50,
|
||||
step=1,
|
||||
)
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
@@ -484,22 +500,24 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
)
|
||||
cnet_input = gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
size="sm",
|
||||
)
|
||||
cnet_input = gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
)
|
||||
with gr.Column():
|
||||
cnet_output = gr.Image(
|
||||
value=None,
|
||||
visible=True,
|
||||
@@ -507,63 +525,66 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
use_input_img.click(
|
||||
import_original,
|
||||
[sd_init_image, canvas_width, canvas_height],
|
||||
[cnet_input],
|
||||
use_result = gr.Button(
|
||||
"Submit",
|
||||
size="sm",
|
||||
)
|
||||
cnet_model.change(
|
||||
fn=update_cn_input,
|
||||
inputs=[
|
||||
cnet_model,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_input,
|
||||
cnet_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_input,
|
||||
],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[cnet_gen.click],
|
||||
fn=cnet_preview,
|
||||
inputs=[
|
||||
cnet_model,
|
||||
cnet_input,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_output,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
cnet_rows.value.append(cnet_row)
|
||||
|
||||
num_cnets.change(
|
||||
show_controlnets,
|
||||
[num_cnets],
|
||||
[cnet_rows, stencils, images, preprocessed_hints],
|
||||
use_input_img.click(
|
||||
import_original,
|
||||
[sd_init_image, canvas_width, canvas_height],
|
||||
[cnet_input],
|
||||
)
|
||||
cnet_model.change(
|
||||
fn=update_cn_input,
|
||||
inputs=[
|
||||
cnet_model,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
],
|
||||
outputs=[
|
||||
cnet_input,
|
||||
cnet_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_input,
|
||||
],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[cnet_gen.click],
|
||||
fn=cnet_preview,
|
||||
inputs=[
|
||||
cnet_model,
|
||||
cnet_input,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
],
|
||||
outputs=[
|
||||
cnet_output,
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
],
|
||||
)
|
||||
use_result.click(
|
||||
fn=submit_to_cnet_config,
|
||||
inputs=[
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
cnet_strength,
|
||||
cnet_config,
|
||||
],
|
||||
outputs=[
|
||||
cnet_config,
|
||||
]
|
||||
)
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -593,6 +614,30 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Group():
|
||||
sd_json = gr.JSON()
|
||||
with gr.Row():
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config", size="sm"
|
||||
)
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
)
|
||||
load_sd_config = gr.FileExplorer(
|
||||
label="Load Config",
|
||||
root=cmd_opts.configs_path if cmd_opts.configs_path else get_configs_path(),
|
||||
height=75,
|
||||
)
|
||||
save_sd_config.click(
|
||||
fn=save_sd_cfg,
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=shark_sd_fn,
|
||||
@@ -614,21 +659,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
sd_custom_vae,
|
||||
precision,
|
||||
device,
|
||||
loras,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
sd_json,
|
||||
],
|
||||
outputs=[
|
||||
sd_gallery,
|
||||
std_output,
|
||||
sd_status,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
show_progress="minimal",
|
||||
)
|
||||
@@ -648,3 +688,15 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[submit_cnet.click],
|
||||
fn=submit_to_main_config,
|
||||
inputs=[cnet_config, sd_json],
|
||||
outputs=[sd_json],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[submit_embeddings.click],
|
||||
fn=submit_to_main_config,
|
||||
inputs=[embeddings_config, sd_json],
|
||||
outputs=[sd_json],
|
||||
)
|
||||
|
||||
77
apps/shark_studio/web/utils/file_utils.py
Normal file
77
apps/shark_studio/web/utils/file_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
checkpoints_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
def get_path_stem(path):
|
||||
path = Path(path)
|
||||
return path.stem
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
|
||||
result = Path(os.path.join(base_path, relative_path)).resolve(strict=False)
|
||||
return result
|
||||
|
||||
def get_configs_path() -> Path:
|
||||
configs = get_resource_path(os.path.join("..", "configs"))
|
||||
if not os.path.exists(configs):
|
||||
os.mkdir(configs)
|
||||
return Path(
|
||||
get_resource_path("../configs")
|
||||
)
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
cmd_opts.output_dir
|
||||
if cmd_opts.output_dir
|
||||
else get_resource_path("../generated_imgs")
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def create_checkpoint_folders():
|
||||
dir = ["vae", "lora"]
|
||||
if not cmd_opts.ckpt_dir:
|
||||
dir.insert(0, "models")
|
||||
else:
|
||||
if not os.path.isdir(cmd_opts.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, "
|
||||
f"{cmd_opts.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_checkpoints_path(model=""):
|
||||
return get_resource_path(f"../models/{model}")
|
||||
|
||||
|
||||
def get_checkpoints(model="models"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_checkpoint_pathfile(checkpoint_name, model="models"):
|
||||
return os.path.join(get_checkpoints_path(model), checkpoint_name)
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from apps.shark_studio.api.utils import (
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_checkpoint_pathfile,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
|
||||
Reference in New Issue
Block a user