Fix JSON API

This commit is contained in:
Ean Garvey
2023-12-14 11:00:06 -06:00
parent 961a5adda4
commit 9645b78281
9 changed files with 413 additions and 196 deletions

View File

@@ -1,6 +1,7 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.img_processing import resize_stencil, save_output_img
from math import ceil
@@ -9,6 +10,8 @@ import torch
import gradio as gr
import PIL
import time
import os
import json
sd_model_map = {
"CompVis/stable-diffusion-v1-4": {
@@ -90,6 +93,24 @@ sd_model_map = {
}
def get_spec(custom_sd_map: dict, sd_embeds: dict):
    """Build a short spec string identifying the pipeline customizations.

    Joins tokens for controlnets, a custom VAE, and the embedding count
    with underscores, e.g. "controlled_myvae_2embeds".

    Args:
        custom_sd_map: mapping of submodel keys (e.g. "control_adapter_0",
            "custom_vae") to their customization dicts.
        sd_embeds: iterable of embeddings; non-None entries are counted.

    Returns:
        Underscore-joined spec string; empty when nothing is customized.
    """
    spec = []
    for key in custom_sd_map:
        # keys shaped like "control_adapter_*" mark a controlnet pipeline
        if "control" in key.split("_"):
            spec.append("controlled")
        elif key == "custom_vae":
            # use the weights filename (sans extension) as the VAE token
            spec.append(custom_sd_map[key]["custom_weights"].split(".")[0])
    num_embeds = sum(1 for embed in sd_embeds if embed is not None)
    if num_embeds:
        # BUG FIX: original did `num_embeds + "embeds"` (int + str -> TypeError)
        spec.append(f"{num_embeds}embeds")
    return "_".join(spec)
class StableDiffusion(SharkPipelineBase):
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
@@ -113,15 +134,17 @@ class StableDiffusion(SharkPipelineBase):
custom_model_map: dict = {},
embeddings: dict = {},
import_ir: bool = True,
is_img2img: bool = False,
):
super().__init__(sd_model_map[base_model_id], device, import_ir)
self.base_model_id = base_model_id
self.device = device
super().__init__(sd_model_map[base_model_id], base_model_id, device, import_ir)
self.precision = precision
self.iree_module_dict = None
self.get_compiled_map()
def prepare_pipeline(self, scheduler, custom_model_map):
self.is_img2img = is_img2img
self.pipe_id = safe_name(base_model_id) + str(height) + str(width) + precision + device + get_spec(custom_model_map, embeddings)
def prepare_pipe(self, scheduler, custom_model_map, embeddings):
print(f"Preparing pipeline with scheduler {scheduler}, custom map {json.dumps(custom_model_map)}, and embeddings {json.dumps(embeddings)}.")
self.get_compiled_map(device=self.device, pipe_id=self.pipe_id)
return None
def generate_images(
@@ -136,26 +159,38 @@ class StableDiffusion(SharkPipelineBase):
repeatable_seeds,
resample_type,
control_mode,
preprocessed_hints,
hints,
):
return None, None, None, None, None
print("Generating Images...")
test_img = [PIL.Image.open(get_resource_path("../../tests/jupiter.png"), mode="r")]
return test_img#, "", ""
# NOTE: Each `hf_model_id` should have its own starting configuration.
# model_vmfb_key = ""
def shark_sd_fn_dict_input(
    sd_kwargs: dict,
):
    """Adapt a JSON/dict config to the shark_sd_fn keyword interface.

    Replaces the "sd_init_image" list of paths with loaded RGB PIL
    images (entries that are not existing files are dropped), then
    forwards every key of *sd_kwargs* to shark_sd_fn.
    """
    loaded_imgs = [
        PIL.Image.open(path, mode="r").convert("RGB")
        for path in sd_kwargs["sd_init_image"]
        if os.path.isfile(path)
    ]
    sd_kwargs["sd_init_image"] = loaded_imgs
    generated_imgs = shark_sd_fn(**sd_kwargs)
    return generated_imgs, "OK", "OK"
def shark_sd_fn(
prompt,
negative_prompt,
image_dict,
sd_init_image,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seeds: list,
seed: list,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -164,23 +199,17 @@ def shark_sd_fn(
custom_vae: str,
precision: str,
device: str,
lora_weights: str | list,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
sd_json: dict,
progress=gr.Progress(),
controlnets: dict,
embeddings: dict,
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
stencils=[]
preprocessed_hints=[]
cnet_strengths=[]
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
elif image_dict:
image = image_dict["image"].convert("RGB")
sd_kwargs = locals()
if isinstance(sd_init_image, PIL.Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
@@ -197,20 +226,33 @@ def shark_sd_fn(
import apps.shark_studio.web.utils.globals as global_obj
custom_model_map = {}
control_mode = None
hints = []
if custom_weights != "None":
custom_model_map["unet"] = {"custom_weights": custom_weights}
if custom_vae != "None":
custom_model_map["vae"] = {"custom_weights": custom_vae}
if stencils:
for i, stencil in enumerate(stencils):
if "model" in controlnets:
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"runwayml/stable-diffusion-v1-5"
][stencil]
custom_model_map[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map[
"runwayml/stable-diffusion-v1-5"
][model],
"strength": controlnets["strength"][i],
}
else:
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
"stabilityai/stable-diffusion-xl-1.0"
][stencil]
custom_model_map[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map[
"stabilityai/stable-diffusion-xl-1.0"
][model],
"strength": controlnets["strength"][i]
}
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append[i]
print(json.dumps(custom_model_map))
submit_pipe_kwargs = {
"base_model_id": base_model_id,
@@ -219,13 +261,14 @@ def shark_sd_fn(
"precision": precision,
"device": device,
"custom_model_map": custom_model_map,
"embeddings": embeddings,
"import_ir": cmd_opts.import_mlir,
"is_img2img": is_img2img,
}
submit_prep_kwargs = {
"scheduler": scheduler,
"custom_model_map": custom_model_map,
"embeddings": lora_weights,
"embeddings": embeddings,
}
submit_run_kwargs = {
"prompt": prompt,
@@ -233,18 +276,18 @@ def shark_sd_fn(
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seeds": seeds,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"preprocessed_hints": preprocessed_hints,
"hints": hints,
}
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
):
print("Regenerating pipeline...")
print("Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
@@ -258,27 +301,16 @@ def shark_sd_fn(
)
global_obj.set_sd_obj(sd_pipe)
sd_pipe.prepare_pipe(**submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
ceil(steps / strength),
strength,
guidance_scale,
seeds[current_batch],
stencils,
resample_type=resample_type,
control_mode=control_mode,
preprocessed_hints=preprocessed_hints,
**submit_run_kwargs
)
total_time = time.time() - start_time
text_output = []
text_output += "\n" + global_obj.get_sd_obj().log
text_output += "\n" # + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
@@ -286,17 +318,15 @@ def shark_sd_fn(
# else:
save_output_img(
out_imgs[0],
seeds[current_batch],
sd_json,
seed[current_batch],
sd_kwargs,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
), stencils
yield generated_imgs#, text_output, status_label(
#"Stable Diffusion", current_batch + 1, batch_count, batch_size
#)
return generated_imgs, text_output, "", stencil, image
return generated_imgs, text_output, "", stencil, image
return generated_imgs#, text_output, ""
def cancel_sd():

View File

@@ -1,11 +1,12 @@
import os
import re
import datetime as dt
import json
from csv import DictWriter
from PIL import Image
from pathlib import Path
from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
@@ -23,7 +24,7 @@ resampler_list = resamplers.keys()
def save_output_img(output_img, img_seed, extra_info=None):
from apps.shark_studio.web.utils.file_utils import get_generated_imgs_path, get_generated_imgs_todays_subdir
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if extra_info is None:
extra_info = {}
@@ -33,60 +34,63 @@ def save_output_img(output_img, img_seed, extra_info=None):
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompts"][0][:15])
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
#img_model = cmd_opts.hf_model_id
#if cmd_opts.ckpt_loc:
# img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
img_model = extra_info["base_model_id"]
if extra_info["custom_weights"] not in [None, "None"]:
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
#img_vae = None
#if cmd_opts.custom_vae:
# img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
img_vae = None
if extra_info["custom_vae"]:
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
#img_lora = None
#if cmd_opts.use_lora:
# img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
img_loras = None
if extra_info["embeddings"]:
img_lora = []
for i in extra_info["embeddings"]:
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
img_loras = ", ".join(img_lora)
#if cmd_opts.output_img_format == "jpg":
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
# output_img.save(out_img_path, quality=95, subsampling=0)
#else:
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
# pngInfo = PngImagePlugin.PngInfo()
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
# if cmd_opts.write_metadata_to_png:
# # Using a conditional expression caused problems, so setting a new
# # variable for now.
# if cmd_opts.use_hiresfix:
# png_size_text = (
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
# )
# else:
# png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
#if cmd_opts.use_hiresfix:
# png_size_text = (
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
# )
#else:
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
# pngInfo.add_text(
# "parameters",
# f"{cmd_opts.prompts[0]}"
# f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
# f"\nSteps: {cmd_opts.steps},"
# f"Sampler: {cmd_opts.scheduler}, "
# f"CFG scale: {cmd_opts.guidance_scale}, "
# f"Seed: {img_seed},"
# f"Size: {png_size_text}, "
# f"Model: {img_model}, "
# f"VAE: {img_vae}, "
# f"LoRA: {img_lora}",
# )
pngInfo.add_text(
"parameters",
f"{extra_info['prompt'][0]}"
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
f"\nSteps: {extra_info['steps'][0]},"
f"Sampler: {extra_info['scheduler'][0]}, "
f"CFG scale: {extra_info['guidance_scale'][0]}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_loras}",
)
# output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
# if cmd_opts.output_img_format not in ["png", "jpg"]:
# print(
# f"[ERROR] Format {cmd_opts.output_img_format} is not "
# f"supported yet. Image saved as png instead."
# f"Supported formats: png / jpg"
# )
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
@@ -104,7 +108,6 @@ def save_output_img(output_img, img_seed, extra_info=None):
dictwriter_obj.writerow(new_entry)
csv_obj.close()
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)

View File

@@ -1,4 +1,6 @@
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.web.utils.file_utils import get_checkpoints_path
import gc
class SharkPipelineBase:
@@ -12,33 +14,37 @@ class SharkPipelineBase:
def __init__(
self,
model_map: dict,
base_model_id: str,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.base_model_id = base_model_id
self.device = device
self.import_mlir = import_mlir
self.iree_module_dict = {}
def import_torch_ir(self, submodel, kwargs):
weights = (
submodel["custom_weights"]
if submodel["custom_weights"]
else None
)
torch_ir = self.model_map[submodel]["initializer"](
self.base_model_id, **kwargs, compile_to="torch"
)
self.model_map[submodel]["tempfile_name"] = get_resource_path(
f"{submodel}.torch.tempfile"
)
with open(self.model_map[submodel]["tempfile_name"], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
def import_torch_ir(self, base_model_id):
for submodel in self.model_map:
hf_id = (
submodel["custom_hf_id"]
if submodel["custom_hf_id"]
else base_model_id
)
torch_ir = submodel["initializer"](
hf_id, **submodel["init_kwargs"], compile_to="torch"
)
submodel["tempfile_name"] = get_resource_path(
f"{submodel}.torch.tempfile"
)
with open(submodel["tempfile_name"], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
def load_vmfb(self, submodel):
if self.iree_module_dict[submodel]:
if submodel in self.iree_module_dict:
print(
f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
)
@@ -47,25 +53,51 @@ class SharkPipelineBase:
return submodel["vmfb"]
def merge_custom_map(self, custom_model_map):
for submodel in custom_model_map:
for key in submodel:
self.model_map[submodel][key] = key
print(self.model_map)
def get_compiled_map(self, device) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
def get_local_vmfbs(self, pipe_id):
for submodel in self.model_map:
if not self.iree_module_dict[submodel][vmfb]:
vmfbs = []
vmfb_matches = {}
vmfbs_path = get_checkpoints_path("../vmfbs")
for (dirpath, dirnames, filenames) in os.walk(vmfbs_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if all(keys in file for keys in [submodel, pipe_id]):
print(f"Found existing .vmfb at {file}")
self.iree_module_dict[submodel] = {'vmfb': file}
def get_compiled_map(self, device, pipe_id) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
if not self.import_mlir:
self.get_local_vmfbs(pipe_id)
for submodel in self.model_map:
if submodel in self.iree_module_dict:
if "vmfb" in self.iree_module_dict[submodel]:
continue
if "tempfile_name" not in self.model_map[submodel]:
sub_kwargs = self.model_map[submodel]["kwargs"] if self.model_map[submodel]["kwargs"] else {}
import_torch_ir(submodel, self.base_model_id, **sub_kwargs)
self.iree_module_dict[submodel] = get_iree_compiled_module(
submodel.tempfile_name,
submodel["tempfile_name"],
device=self.device,
frontend="torch",
external_weight_file=submodel["custom_weights"]
)
# TODO: delete the temp file
def run(self, submodel, inputs):
return
def safe_name(name):
    """Return *name* with "/" and "-" replaced by "_" (file-name safe)."""
    return name.translate(str.maketrans("/-", "__"))

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

View File

@@ -0,0 +1,24 @@
{
"prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
"negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
"sd_init_image": [ "None" ],
"height": 512,
"width": 512,
"steps": [ 50 ],
"strength": [ 0.8 ],
"guidance_scale": [ 7.5 ],
"seed": [ -1 ],
"batch_count": 1,
"batch_size": 1,
"scheduler": [ "EulerDiscrete" ],
"base_model_id": "runwayml/stable-diffusion-v1-5",
"custom_weights": "",
"custom_vae": "",
"precision": "fp16",
"device": "vulkan",
"ondemand": "False",
"repeatable_seeds": "False",
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}

View File

@@ -1 +0,0 @@
{}

View File

@@ -2,8 +2,6 @@ import os
import json
import gradio as gr
import numpy as np
from math import ceil
from inspect import signature
from PIL import Image
from pathlib import Path
@@ -24,7 +22,7 @@ from apps.shark_studio.web.utils.file_utils import (
)
from apps.shark_studio.api.sd import (
sd_model_map,
shark_sd_fn,
shark_sd_fn_dict_input,
cancel_sd,
)
from apps.shark_studio.api.controlnet import (
@@ -47,18 +45,19 @@ from apps.shark_studio.web.utils.state import (
from apps.shark_studio.web.ui.common_events import lora_changed
def view_json_file(file_obj):
def view_json_file(file_path):
content = ""
with open(file_obj.name, "r") as fopen:
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: int, curr_config: dict):
def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: int, control_mode: str, curr_config: dict):
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
return gr.update()
if curr_config is not None:
if "controlnets" in curr_config:
if "controlnets" in curr_config:
curr_config["controlnets"]["control_mode"] = control_mode
curr_config["controlnets"]["model"].append(stencil)
curr_config["controlnets"]["hint"].append(preprocessed_hint)
curr_config["controlnets"]["strength"].append(cnet_strength)
@@ -66,6 +65,7 @@ def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: i
cnet_map = {}
cnet_map["controlnets"] = {
"control_mode": control_mode,
"model": [stencil],
"hint": [preprocessed_hint],
"strength": [cnet_strength],
@@ -85,8 +85,7 @@ def update_embeddings_json(embedding, curr_config: dict):
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
if main_cfg in [None, ""]:
# only time main_cfg should be a string is empty case.
if main_cfg in [None, "", {}]:
return input_cfg
for base_key in input_cfg:
@@ -94,6 +93,75 @@ def submit_to_main_config(input_cfg: dict, main_cfg: dict):
return main_cfg
def pull_sd_configs(
    prompt,
    negative_prompt,
    sd_init_image,
    height,
    width,
    steps,
    strength,
    guidance_scale,
    seed,
    batch_count,
    batch_size,
    scheduler,
    base_model_id,
    custom_weights,
    custom_vae,
    precision,
    device,
    ondemand,
    repeatable_seeds,
    resample_type,
    sd_json,
):
    """Collect the UI argument values into the sd_json config dict.

    Every parameter except *sd_json* itself is written into sd_json
    under its own parameter name; the updated dict is returned.
    """
    sd_args = locals()
    for arg in sd_args:
        # BUG FIX: skip "sd_json" itself — copying it made the dict
        # self-referential, which json serialization cannot handle.
        if arg != "sd_json":
            sd_json[arg] = sd_args[arg]
    return sd_json
def load_sd_cfg(sd_json: dict, load_sd_config: str):
    """Load a saved SD config file and merge it over the current sd_json.

    Reads *load_sd_config* as JSON, overlays its keys onto *sd_json*
    (or adopts it wholesale when sd_json is empty), loads the init image
    if its path exists, and returns the values ordered to match the
    gradio output components, ending with the merged config itself.
    """
    loaded_cfg = json.loads(view_json_file(load_sd_config))
    if sd_json:
        sd_json.update(loaded_cfg)
    else:
        sd_json = loaded_cfg
    init_path = sd_json["sd_init_image"][0]
    if os.path.isfile(init_path):
        sd_image = Image.open(init_path, mode="r")
    else:
        sd_image = None
    return [
        sd_json["prompt"][0],
        sd_json["negative_prompt"][0],
        sd_image,
        sd_json["height"],
        sd_json["width"],
        sd_json["steps"][0],
        sd_json["strength"][0],
        sd_json["guidance_scale"],
        sd_json["seed"][0],
        sd_json["batch_count"],
        sd_json["batch_size"],
        sd_json["scheduler"][0],
        sd_json["base_model_id"],
        sd_json["custom_weights"],
        sd_json["custom_vae"],
        sd_json["precision"],
        sd_json["device"],
        sd_json["ondemand"],
        sd_json["repeatable_seeds"],
        sd_json["resample_type"],
        sd_json["controlnets"],
        sd_json["embeddings"],
        sd_json,
    ]
def save_sd_cfg(config: dict, save_name: str):
if os.path.exists(save_name):
filepath=save_name
@@ -213,15 +281,7 @@ def update_cn_input(
]
sd_fn_inputs = []
sd_fn_sig = signature(shark_sd_fn).replace()
for i in sd_fn_sig.parameters:
sd_fn_inputs.append(i)
with gr.Blocks(title="Stable Diffusion") as sd_element:
# Get a list of arguments needed for the API call, then
# initialize an empty list that will manage the corresponding
# gradio values.
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row(variant="compact", equal_height=True):
@@ -239,34 +299,34 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
show_download_button=False,
)
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Row(variant="compact"):
with gr.Column(scale=1, min_width=600):
with gr.Row(equal_height=True):
with gr.Column(scale=3):
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoints_path())}"
)
sd_base = gr.Dropdown(
base_model_id = gr.Dropdown(
label="Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_model_map.keys(),
) # base_model_id
sd_custom_weights = gr.Dropdown(
custom_weights = gr.Dropdown(
label="Custom Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=get_checkpoints(sd_base),
choices=get_checkpoints(base_model_id),
) #
with gr.Column(scale=2):
sd_vae_info = (
str(get_checkpoints_path("vae"))
).replace("\\", "\n\\")
sd_vae_info = f"VAE Path: {sd_vae_info}"
sd_custom_vae = gr.Dropdown(
custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
info=sd_vae_info,
elem_id="custom_model",
@@ -513,7 +573,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
visible=True,
image_mode="RGB",
interactive=True,
show_label=True,
show_label=False,
label="Input Image",
type="pil",
)
@@ -558,6 +618,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
[
cnet_input,
],
queue=True,
show_progress=False,
)
gr.on(
triggers=[cnet_gen.click],
@@ -573,6 +635,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
stencil,
preprocessed_hint,
],
queue=True,
show_progress=False,
)
use_result.click(
fn=submit_to_cnet_config,
@@ -580,11 +644,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
stencil,
preprocessed_hint,
cnet_strength,
control_mode,
cnet_config,
],
outputs=[
cnet_config,
]
],
queue=True,
show_progress=False,
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -615,32 +682,68 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
)
stop_batch = gr.Button("Stop Batch")
with gr.Group():
sd_json = gr.JSON()
with gr.Row():
with gr.Column(scale=3):
sd_json = gr.JSON(value=view_json_file(os.path.join(get_configs_path(), "default_sd_config.json")))
with gr.Column(scale=1):
clear_sd_config = gr.ClearButton(
value="Clear Config", size="sm"
)
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
value="Clear Config", size="sm", components=sd_json
)
with gr.Row():
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
)
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root=cmd_opts.configs_path if cmd_opts.configs_path else get_configs_path(),
height=75,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
ondemand,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
queue=True,
show_progress=False,
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
queue=True,
show_progress=False,
)
kwargs = dict(
fn=shark_sd_fn,
pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[
prompt,
negative_prompt,
@@ -654,23 +757,19 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
batch_count,
batch_size,
scheduler,
sd_base,
sd_custom_weights,
sd_custom_vae,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
sd_json,
],
outputs=[
sd_gallery,
std_output,
sd_status,
sd_json
],
show_progress="minimal",
)
status_kwargs = dict(
@@ -679,11 +778,23 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
outputs=sd_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
gen_kwargs = dict(
fn=shark_sd_fn_dict_input,
inputs=[sd_json],
outputs=[
sd_gallery,
std_output,
sd_status
],
queue=True,
show_progress="minimal"
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**pull_kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],

View File

@@ -1,7 +1,7 @@
import os
import sys
import glob
import datetime as dt
from datetime import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -11,6 +11,9 @@ checkpoints_filetypes = (
"*.safetensors",
)
def safe_name(name):
    """Sanitize a model id for file names: "/" and "-" become "_"."""
    sanitized = name.replace("/", "_")
    return sanitized.replace("-", "_")
def get_path_stem(path):
path = Path(path)
return path.stem

View File

@@ -9,10 +9,12 @@ Also we could avoid memory leak when switching models by clearing the cache.
def _init():
global _sd_obj
global _config_obj
global _pipe_kwargs
global _gen_kwargs
global _schedulers
_sd_obj = None
_config_obj = None
_pipe_kwargs = None
_gen_kwargs = None
_schedulers = None
@@ -31,9 +33,14 @@ def set_sd_status(value):
_sd_obj.status = value
def set_cfg_obj(value):
global _config_obj
_config_obj = value
def set_pipe_kwargs(value):
global _pipe_kwargs
_pipe_kwargs = value
def set_gen_kwargs(value):
global _gen_kwargs
_gen_kwargs = value
def set_schedulers(value):
@@ -51,9 +58,14 @@ def get_sd_status():
return _sd_obj.status
def get_cfg_obj():
global _config_obj
return _config_obj
def get_pipe_kwargs():
global _pipe_kwargs
return _pipe_kwargs
def get_gen_kwargs():
global _gen_kwargs
return _gen_kwargs
def get_scheduler(key):
@@ -63,12 +75,15 @@ def get_scheduler(key):
def clear_cache():
global _sd_obj
global _config_obj
global _pipe_kwargs
global _gen_kwargs
global _schedulers
del _sd_obj
del _config_obj
del _pipe_kwargs
del _gen_kwargs
del _schedulers
gc.collect()
_sd_obj = None
_config_obj = None
_pipe_kwargs = None
_gen_kwargs = None
_schedulers = None