mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fixing json API
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from turbine_models.custom_models.sd_inference import clip, unet, vae
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
from apps.shark_studio.web.utils.state import status_label
|
||||
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path
|
||||
from apps.shark_studio.modules.pipeline import SharkPipelineBase
|
||||
from apps.shark_studio.modules.img_processing import resize_stencil, save_output_img
|
||||
from math import ceil
|
||||
@@ -9,6 +10,8 @@ import torch
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
|
||||
sd_model_map = {
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
@@ -90,6 +93,24 @@ sd_model_map = {
|
||||
}
|
||||
|
||||
|
||||
def get_spec(custom_sd_map: dict, sd_embeds: dict):
|
||||
spec = []
|
||||
for key in custom_sd_map:
|
||||
if "control" in key.split("_"):
|
||||
spec.append("controlled")
|
||||
elif key == "custom_vae":
|
||||
spec.append(custom_sd_map[key]["custom_weights"].split(".")[0])
|
||||
num_embeds = 0
|
||||
embeddings_spec = None
|
||||
for embed in sd_embeds:
|
||||
if embed is not None:
|
||||
num_embeds += 1
|
||||
embeddings_spec = num_embeds + "embeds"
|
||||
if embeddings_spec:
|
||||
spec.append(embeddings_spec)
|
||||
return "_".join(spec)
|
||||
|
||||
|
||||
class StableDiffusion(SharkPipelineBase):
|
||||
# This class is responsible for executing image generation and creating
|
||||
# /managing a set of compiled modules to run Stable Diffusion. The init
|
||||
@@ -113,15 +134,17 @@ class StableDiffusion(SharkPipelineBase):
|
||||
custom_model_map: dict = {},
|
||||
embeddings: dict = {},
|
||||
import_ir: bool = True,
|
||||
is_img2img: bool = False,
|
||||
):
|
||||
super().__init__(sd_model_map[base_model_id], device, import_ir)
|
||||
self.base_model_id = base_model_id
|
||||
self.device = device
|
||||
super().__init__(sd_model_map[base_model_id], base_model_id, device, import_ir)
|
||||
self.precision = precision
|
||||
self.iree_module_dict = None
|
||||
self.get_compiled_map()
|
||||
|
||||
def prepare_pipeline(self, scheduler, custom_model_map):
|
||||
self.is_img2img = is_img2img
|
||||
self.pipe_id = safe_name(base_model_id) + str(height) + str(width) + precision + device + get_spec(custom_model_map, embeddings)
|
||||
|
||||
|
||||
def prepare_pipe(self, scheduler, custom_model_map, embeddings):
|
||||
print(f"Preparing pipeline with scheduler {scheduler}, custom map {json.dumps(custom_model_map)}, and embeddings {json.dumps(embeddings)}.")
|
||||
self.get_compiled_map(device=self.device, pipe_id=self.pipe_id)
|
||||
return None
|
||||
|
||||
def generate_images(
|
||||
@@ -136,26 +159,38 @@ class StableDiffusion(SharkPipelineBase):
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints,
|
||||
hints,
|
||||
):
|
||||
return None, None, None, None, None
|
||||
print("Generating Images...")
|
||||
test_img = [PIL.Image.open(get_resource_path("../../tests/jupiter.png"), mode="r")]
|
||||
return test_img#, "", ""
|
||||
|
||||
|
||||
# NOTE: Each `hf_model_id` should have its own starting configuration.
|
||||
|
||||
# model_vmfb_key = ""
|
||||
def shark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
input_imgs=[]
|
||||
img_paths = sd_kwargs["sd_init_image"]
|
||||
for img_path in img_paths:
|
||||
if os.path.isfile(img_path):
|
||||
input_imgs.append(PIL.Image.open(img_path, mode='r').convert("RGB"))
|
||||
sd_kwargs["sd_init_image"] = input_imgs
|
||||
generated_imgs = shark_sd_fn(
|
||||
**sd_kwargs
|
||||
)
|
||||
return generated_imgs, "OK", "OK"
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_dict,
|
||||
sd_init_image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seeds: list,
|
||||
seed: list,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -164,23 +199,17 @@ def shark_sd_fn(
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
lora_weights: str | list,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
sd_json: dict,
|
||||
progress=gr.Progress(),
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
):
|
||||
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
|
||||
stencils=[]
|
||||
preprocessed_hints=[]
|
||||
cnet_strengths=[]
|
||||
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
sd_kwargs = locals()
|
||||
if isinstance(sd_init_image, PIL.Image.Image):
|
||||
image = sd_init_image.convert("RGB")
|
||||
elif sd_init_image:
|
||||
image = sd_init_image["image"].convert("RGB")
|
||||
else:
|
||||
image = None
|
||||
is_img2img = False
|
||||
@@ -197,20 +226,33 @@ def shark_sd_fn(
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
custom_model_map = {}
|
||||
control_mode = None
|
||||
hints = []
|
||||
if custom_weights != "None":
|
||||
custom_model_map["unet"] = {"custom_weights": custom_weights}
|
||||
if custom_vae != "None":
|
||||
custom_model_map["vae"] = {"custom_weights": custom_vae}
|
||||
if stencils:
|
||||
for i, stencil in enumerate(stencils):
|
||||
if "model" in controlnets:
|
||||
for i, model in enumerate(controlnets["model"]):
|
||||
if "xl" not in base_model_id.lower():
|
||||
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
|
||||
"runwayml/stable-diffusion-v1-5"
|
||||
][stencil]
|
||||
custom_model_map[f"control_adapter_{model}"] = {
|
||||
"hf_id": control_adapter_map[
|
||||
"runwayml/stable-diffusion-v1-5"
|
||||
][model],
|
||||
"strength": controlnets["strength"][i],
|
||||
}
|
||||
else:
|
||||
custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
|
||||
"stabilityai/stable-diffusion-xl-1.0"
|
||||
][stencil]
|
||||
custom_model_map[f"control_adapter_{model}"] = {
|
||||
"hf_id": control_adapter_map[
|
||||
"stabilityai/stable-diffusion-xl-1.0"
|
||||
][model],
|
||||
"strength": controlnets["strength"][i]
|
||||
}
|
||||
control_mode = controlnets["control_mode"]
|
||||
for i in controlnets["hint"]:
|
||||
hints.append[i]
|
||||
|
||||
print(json.dumps(custom_model_map))
|
||||
|
||||
submit_pipe_kwargs = {
|
||||
"base_model_id": base_model_id,
|
||||
@@ -219,13 +261,14 @@ def shark_sd_fn(
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"custom_model_map": custom_model_map,
|
||||
"embeddings": embeddings,
|
||||
"import_ir": cmd_opts.import_mlir,
|
||||
"is_img2img": is_img2img,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
"scheduler": scheduler,
|
||||
"custom_model_map": custom_model_map,
|
||||
"embeddings": lora_weights,
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
submit_run_kwargs = {
|
||||
"prompt": prompt,
|
||||
@@ -233,18 +276,18 @@ def shark_sd_fn(
|
||||
"steps": steps,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seeds": seeds,
|
||||
"seed": seed,
|
||||
"ondemand": ondemand,
|
||||
"repeatable_seeds": repeatable_seeds,
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"preprocessed_hints": preprocessed_hints,
|
||||
"hints": hints,
|
||||
}
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
):
|
||||
print("Regenerating pipeline...")
|
||||
print("Initializing new pipeline...")
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
|
||||
@@ -258,27 +301,16 @@ def shark_sd_fn(
|
||||
)
|
||||
global_obj.set_sd_obj(sd_pipe)
|
||||
|
||||
sd_pipe.prepare_pipe(**submit_prep_kwargs)
|
||||
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
|
||||
generated_imgs = []
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
stencils,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
preprocessed_hints=preprocessed_hints,
|
||||
**submit_run_kwargs
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = []
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += "\n" # + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
@@ -286,17 +318,15 @@ def shark_sd_fn(
|
||||
# else:
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
seeds[current_batch],
|
||||
sd_json,
|
||||
seed[current_batch],
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
), stencils
|
||||
yield generated_imgs#, text_output, status_label(
|
||||
#"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
#)
|
||||
|
||||
return generated_imgs, text_output, "", stencil, image
|
||||
|
||||
return generated_imgs, text_output, "", stencil, image
|
||||
return generated_imgs#, text_output, ""
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import os
|
||||
import re
|
||||
import datetime as dt
|
||||
import json
|
||||
from csv import DictWriter
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
from csv import DictWriter
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
from base64 import decode
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
@@ -23,7 +24,7 @@ resampler_list = resamplers.keys()
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
from apps.shark_studio.web.utils.file_utils import get_generated_imgs_path, get_generated_imgs_todays_subdir
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
@@ -33,60 +34,63 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompts"][0][:15])
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
#img_model = cmd_opts.hf_model_id
|
||||
#if cmd_opts.ckpt_loc:
|
||||
# img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
|
||||
img_model = extra_info["base_model_id"]
|
||||
if extra_info["custom_weights"] not in [None, "None"]:
|
||||
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
|
||||
|
||||
#img_vae = None
|
||||
#if cmd_opts.custom_vae:
|
||||
# img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
|
||||
img_vae = None
|
||||
if extra_info["custom_vae"]:
|
||||
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
|
||||
|
||||
#img_lora = None
|
||||
#if cmd_opts.use_lora:
|
||||
# img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
img_loras = None
|
||||
if extra_info["embeddings"]:
|
||||
img_lora = []
|
||||
for i in extra_info["embeddings"]:
|
||||
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
|
||||
img_loras = ", ".join(img_lora)
|
||||
|
||||
#if cmd_opts.output_img_format == "jpg":
|
||||
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
# output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
#else:
|
||||
# out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
# pngInfo = PngImagePlugin.PngInfo()
|
||||
if cmd_opts.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
# if cmd_opts.write_metadata_to_png:
|
||||
# # Using a conditional expression caused problems, so setting a new
|
||||
# # variable for now.
|
||||
# if cmd_opts.use_hiresfix:
|
||||
# png_size_text = (
|
||||
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
# )
|
||||
# else:
|
||||
# png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
|
||||
if cmd_opts.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
#if cmd_opts.use_hiresfix:
|
||||
# png_size_text = (
|
||||
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
|
||||
# )
|
||||
#else:
|
||||
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
|
||||
|
||||
# pngInfo.add_text(
|
||||
# "parameters",
|
||||
# f"{cmd_opts.prompts[0]}"
|
||||
# f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
|
||||
# f"\nSteps: {cmd_opts.steps},"
|
||||
# f"Sampler: {cmd_opts.scheduler}, "
|
||||
# f"CFG scale: {cmd_opts.guidance_scale}, "
|
||||
# f"Seed: {img_seed},"
|
||||
# f"Size: {png_size_text}, "
|
||||
# f"Model: {img_model}, "
|
||||
# f"VAE: {img_vae}, "
|
||||
# f"LoRA: {img_lora}",
|
||||
# )
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{extra_info['prompt'][0]}"
|
||||
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
|
||||
f"\nSteps: {extra_info['steps'][0]},"
|
||||
f"Sampler: {extra_info['scheduler'][0]}, "
|
||||
f"CFG scale: {extra_info['guidance_scale'][0]}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_loras}",
|
||||
)
|
||||
|
||||
# output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
# if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
# print(
|
||||
# f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
# f"supported yet. Image saved as png instead."
|
||||
# f"Supported formats: png / jpg"
|
||||
# )
|
||||
if cmd_opts.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {cmd_opts.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
@@ -104,7 +108,6 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
from apps.shark_studio.web.utils.file_utils import get_checkpoints_path
|
||||
import gc
|
||||
|
||||
|
||||
class SharkPipelineBase:
|
||||
@@ -12,33 +14,37 @@ class SharkPipelineBase:
|
||||
def __init__(
|
||||
self,
|
||||
model_map: dict,
|
||||
base_model_id: str,
|
||||
device: str,
|
||||
import_mlir: bool = True,
|
||||
):
|
||||
self.model_map = model_map
|
||||
self.base_model_id = base_model_id
|
||||
self.device = device
|
||||
self.import_mlir = import_mlir
|
||||
self.iree_module_dict = {}
|
||||
|
||||
|
||||
def import_torch_ir(self, submodel, kwargs):
|
||||
weights = (
|
||||
submodel["custom_weights"]
|
||||
if submodel["custom_weights"]
|
||||
else None
|
||||
)
|
||||
torch_ir = self.model_map[submodel]["initializer"](
|
||||
self.base_model_id, **kwargs, compile_to="torch"
|
||||
)
|
||||
self.model_map[submodel]["tempfile_name"] = get_resource_path(
|
||||
f"{submodel}.torch.tempfile"
|
||||
)
|
||||
with open(self.model_map[submodel]["tempfile_name"], "w+") as f:
|
||||
f.write(torch_ir)
|
||||
del torch_ir
|
||||
gc.collect()
|
||||
|
||||
def import_torch_ir(self, base_model_id):
|
||||
for submodel in self.model_map:
|
||||
hf_id = (
|
||||
submodel["custom_hf_id"]
|
||||
if submodel["custom_hf_id"]
|
||||
else base_model_id
|
||||
)
|
||||
torch_ir = submodel["initializer"](
|
||||
hf_id, **submodel["init_kwargs"], compile_to="torch"
|
||||
)
|
||||
submodel["tempfile_name"] = get_resource_path(
|
||||
f"{submodel}.torch.tempfile"
|
||||
)
|
||||
with open(submodel["tempfile_name"], "w+") as f:
|
||||
f.write(torch_ir)
|
||||
del torch_ir
|
||||
gc.collect()
|
||||
|
||||
def load_vmfb(self, submodel):
|
||||
if self.iree_module_dict[submodel]:
|
||||
if submodel in self.iree_module_dict:
|
||||
print(
|
||||
f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
|
||||
)
|
||||
@@ -47,25 +53,51 @@ class SharkPipelineBase:
|
||||
|
||||
return submodel["vmfb"]
|
||||
|
||||
|
||||
def merge_custom_map(self, custom_model_map):
|
||||
for submodel in custom_model_map:
|
||||
for key in submodel:
|
||||
self.model_map[submodel][key] = key
|
||||
print(self.model_map)
|
||||
|
||||
def get_compiled_map(self, device) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
|
||||
def get_local_vmfbs(self, pipe_id):
|
||||
for submodel in self.model_map:
|
||||
if not self.iree_module_dict[submodel][vmfb]:
|
||||
vmfbs = []
|
||||
vmfb_matches = {}
|
||||
vmfbs_path = get_checkpoints_path("../vmfbs")
|
||||
for (dirpath, dirnames, filenames) in os.walk(vmfbs_path):
|
||||
vmfbs.extend(filenames)
|
||||
break
|
||||
for file in vmfbs:
|
||||
if all(keys in file for keys in [submodel, pipe_id]):
|
||||
print(f"Found existing .vmfb at {file}")
|
||||
self.iree_module_dict[submodel] = {'vmfb': file}
|
||||
|
||||
|
||||
def get_compiled_map(self, device, pipe_id) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
if not self.import_mlir:
|
||||
self.get_local_vmfbs(pipe_id)
|
||||
for submodel in self.model_map:
|
||||
if submodel in self.iree_module_dict:
|
||||
if "vmfb" in self.iree_module_dict[submodel]:
|
||||
continue
|
||||
if "tempfile_name" not in self.model_map[submodel]:
|
||||
sub_kwargs = self.model_map[submodel]["kwargs"] if self.model_map[submodel]["kwargs"] else {}
|
||||
import_torch_ir(submodel, self.base_model_id, **sub_kwargs)
|
||||
self.iree_module_dict[submodel] = get_iree_compiled_module(
|
||||
submodel.tempfile_name,
|
||||
submodel["tempfile_name"],
|
||||
device=self.device,
|
||||
frontend="torch",
|
||||
external_weight_file=submodel["custom_weights"]
|
||||
)
|
||||
# TODO: delete the temp file
|
||||
|
||||
|
||||
def run(self, submodel, inputs):
|
||||
return
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("-", "_")
|
||||
|
||||
BIN
apps/shark_studio/tests/jupiter.png
Normal file
BIN
apps/shark_studio/tests/jupiter.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 347 KiB |
24
apps/shark_studio/web/configs/default_sd_config.json
Normal file
24
apps/shark_studio/web/configs/default_sd_config.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
|
||||
"negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
|
||||
"sd_init_image": [ "None" ],
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"steps": [ 50 ],
|
||||
"strength": [ 0.8 ],
|
||||
"guidance_scale": [ 7.5 ],
|
||||
"seed": [ -1 ],
|
||||
"batch_count": 1,
|
||||
"batch_size": 1,
|
||||
"scheduler": [ "EulerDiscrete" ],
|
||||
"base_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"custom_weights": "",
|
||||
"custom_vae": "",
|
||||
"precision": "fp16",
|
||||
"device": "vulkan",
|
||||
"ondemand": "False",
|
||||
"repeatable_seeds": "False",
|
||||
"resample_type": "Nearest Neighbor",
|
||||
"controlnets": {},
|
||||
"embeddings": {}
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
{}
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
import json
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
from math import ceil
|
||||
from inspect import signature
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
@@ -24,7 +22,7 @@ from apps.shark_studio.web.utils.file_utils import (
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
sd_model_map,
|
||||
shark_sd_fn,
|
||||
shark_sd_fn_dict_input,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
@@ -47,18 +45,19 @@ from apps.shark_studio.web.utils.state import (
|
||||
from apps.shark_studio.web.ui.common_events import lora_changed
|
||||
|
||||
|
||||
def view_json_file(file_obj):
|
||||
def view_json_file(file_path):
|
||||
content = ""
|
||||
with open(file_obj.name, "r") as fopen:
|
||||
with open(file_path, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
|
||||
def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: int, curr_config: dict):
|
||||
def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: int, control_mode: str, curr_config: dict):
|
||||
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
|
||||
return gr.update()
|
||||
if curr_config is not None:
|
||||
if "controlnets" in curr_config:
|
||||
if "controlnets" in curr_config:
|
||||
curr_config["controlnets"]["control_mode"] = control_mode
|
||||
curr_config["controlnets"]["model"].append(stencil)
|
||||
curr_config["controlnets"]["hint"].append(preprocessed_hint)
|
||||
curr_config["controlnets"]["strength"].append(cnet_strength)
|
||||
@@ -66,6 +65,7 @@ def submit_to_cnet_config(stencil: str, preprocessed_hint: str, cnet_strength: i
|
||||
|
||||
cnet_map = {}
|
||||
cnet_map["controlnets"] = {
|
||||
"control_mode": control_mode,
|
||||
"model": [stencil],
|
||||
"hint": [preprocessed_hint],
|
||||
"strength": [cnet_strength],
|
||||
@@ -85,8 +85,7 @@ def update_embeddings_json(embedding, curr_config: dict):
|
||||
|
||||
|
||||
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
|
||||
if main_cfg in [None, ""]:
|
||||
# only time main_cfg should be a string is empty case.
|
||||
if main_cfg in [None, "", {}]:
|
||||
return input_cfg
|
||||
|
||||
for base_key in input_cfg:
|
||||
@@ -94,6 +93,75 @@ def submit_to_main_config(input_cfg: dict, main_cfg: dict):
|
||||
return main_cfg
|
||||
|
||||
|
||||
def pull_sd_configs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
sd_json,
|
||||
):
|
||||
sd_args = locals()
|
||||
for arg in sd_args:
|
||||
sd_json[arg] = sd_args[arg]
|
||||
|
||||
return sd_json
|
||||
|
||||
|
||||
def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
new_sd_config = json.loads(view_json_file(load_sd_config))
|
||||
if sd_json:
|
||||
for key in new_sd_config:
|
||||
sd_json[key] = new_sd_config[key]
|
||||
else:
|
||||
sd_json = new_sd_config
|
||||
if os.path.isfile(sd_json["sd_init_image"][0]):
|
||||
sd_image = Image.open(sd_json["sd_init_image"][0], mode='r')
|
||||
else:
|
||||
sd_image = None
|
||||
|
||||
return [
|
||||
sd_json["prompt"][0],
|
||||
sd_json["negative_prompt"][0],
|
||||
sd_image,
|
||||
sd_json["height"],
|
||||
sd_json["width"],
|
||||
sd_json["steps"][0],
|
||||
sd_json["strength"][0],
|
||||
sd_json["guidance_scale"],
|
||||
sd_json["seed"][0],
|
||||
sd_json["batch_count"],
|
||||
sd_json["batch_size"],
|
||||
sd_json["scheduler"][0],
|
||||
sd_json["base_model_id"],
|
||||
sd_json["custom_weights"],
|
||||
sd_json["custom_vae"],
|
||||
sd_json["precision"],
|
||||
sd_json["device"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["resample_type"],
|
||||
sd_json["controlnets"],
|
||||
sd_json["embeddings"],
|
||||
sd_json,
|
||||
]
|
||||
|
||||
|
||||
def save_sd_cfg(config: dict, save_name: str):
|
||||
if os.path.exists(save_name):
|
||||
filepath=save_name
|
||||
@@ -213,15 +281,7 @@ def update_cn_input(
|
||||
]
|
||||
|
||||
|
||||
sd_fn_inputs = []
|
||||
sd_fn_sig = signature(shark_sd_fn).replace()
|
||||
for i in sd_fn_sig.parameters:
|
||||
sd_fn_inputs.append(i)
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
# Get a list of arguments needed for the API call, then
|
||||
# initialize an empty list that will manage the corresponding
|
||||
# gradio values.
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row(variant="compact", equal_height=True):
|
||||
@@ -239,34 +299,34 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Column(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Row(variant="compact"):
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
sd_model_info = (
|
||||
f"Checkpoint Path: {str(get_checkpoints_path())}"
|
||||
)
|
||||
sd_base = gr.Dropdown(
|
||||
base_model_id = gr.Dropdown(
|
||||
label="Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
choices=sd_model_map.keys(),
|
||||
) # base_model_id
|
||||
sd_custom_weights = gr.Dropdown(
|
||||
custom_weights = gr.Dropdown(
|
||||
label="Custom Weights",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
allow_custom_value=True,
|
||||
choices=get_checkpoints(sd_base),
|
||||
choices=get_checkpoints(base_model_id),
|
||||
) #
|
||||
with gr.Column(scale=2):
|
||||
sd_vae_info = (
|
||||
str(get_checkpoints_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
sd_vae_info = f"VAE Path: {sd_vae_info}"
|
||||
sd_custom_vae = gr.Dropdown(
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom VAE Models",
|
||||
info=sd_vae_info,
|
||||
elem_id="custom_model",
|
||||
@@ -513,7 +573,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
show_label=False,
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
)
|
||||
@@ -558,6 +618,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
[
|
||||
cnet_input,
|
||||
],
|
||||
queue=True,
|
||||
show_progress=False,
|
||||
)
|
||||
gr.on(
|
||||
triggers=[cnet_gen.click],
|
||||
@@ -573,6 +635,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
],
|
||||
queue=True,
|
||||
show_progress=False,
|
||||
)
|
||||
use_result.click(
|
||||
fn=submit_to_cnet_config,
|
||||
@@ -580,11 +644,14 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
stencil,
|
||||
preprocessed_hint,
|
||||
cnet_strength,
|
||||
control_mode,
|
||||
cnet_config,
|
||||
],
|
||||
outputs=[
|
||||
cnet_config,
|
||||
]
|
||||
],
|
||||
queue=True,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -615,32 +682,68 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Group():
|
||||
sd_json = gr.JSON()
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
sd_json = gr.JSON(value=view_json_file(os.path.join(get_configs_path(), "default_sd_config.json")))
|
||||
with gr.Column(scale=1):
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config", size="sm"
|
||||
)
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
value="Clear Config", size="sm", components=sd_json
|
||||
)
|
||||
with gr.Row():
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
)
|
||||
load_sd_config = gr.FileExplorer(
|
||||
label="Load Config",
|
||||
file_count="single",
|
||||
root=cmd_opts.configs_path if cmd_opts.configs_path else get_configs_path(),
|
||||
height=75,
|
||||
)
|
||||
load_sd_config.change(
|
||||
fn=load_sd_cfg,
|
||||
inputs=[sd_json, load_sd_config],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
sd_json,
|
||||
],
|
||||
queue=True,
|
||||
show_progress=False,
|
||||
)
|
||||
save_sd_config.click(
|
||||
fn=save_sd_cfg,
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
queue=True,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=shark_sd_fn,
|
||||
pull_kwargs = dict(
|
||||
fn=pull_sd_configs,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -654,23 +757,19 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
sd_base,
|
||||
sd_custom_weights,
|
||||
sd_custom_vae,
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
sd_json,
|
||||
],
|
||||
outputs=[
|
||||
sd_gallery,
|
||||
std_output,
|
||||
sd_status,
|
||||
sd_json
|
||||
],
|
||||
show_progress="minimal",
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
@@ -679,11 +778,23 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
outputs=sd_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
gen_kwargs = dict(
|
||||
fn=shark_sd_fn_dict_input,
|
||||
inputs=[sd_json],
|
||||
outputs=[
|
||||
sd_gallery,
|
||||
std_output,
|
||||
sd_status
|
||||
],
|
||||
queue=True,
|
||||
show_progress="minimal"
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**pull_kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import datetime as dt
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
@@ -11,6 +11,9 @@ checkpoints_filetypes = (
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
def safe_name(name):
|
||||
return name.replace("/", "_").replace("-", "_")
|
||||
|
||||
def get_path_stem(path):
|
||||
path = Path(path)
|
||||
return path.stem
|
||||
|
||||
@@ -9,10 +9,12 @@ Also we could avoid memory leak when switching models by clearing the cache.
|
||||
|
||||
def _init():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _pipe_kwargs
|
||||
global _gen_kwargs
|
||||
global _schedulers
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_pipe_kwargs = None
|
||||
_gen_kwargs = None
|
||||
_schedulers = None
|
||||
|
||||
|
||||
@@ -31,9 +33,14 @@ def set_sd_status(value):
|
||||
_sd_obj.status = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global _config_obj
|
||||
_config_obj = value
|
||||
def set_pipe_kwargs(value):
|
||||
global _pipe_kwargs
|
||||
_pipe_kwargs = value
|
||||
|
||||
|
||||
def set_gen_kwargs(value):
|
||||
global _gen_kwargs
|
||||
_gen_kwargs = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
@@ -51,9 +58,14 @@ def get_sd_status():
|
||||
return _sd_obj.status
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
global _config_obj
|
||||
return _config_obj
|
||||
def get_pipe_kwargs():
|
||||
global _pipe_kwargs
|
||||
return _pipe_kwargs
|
||||
|
||||
|
||||
def get_gen_kwargs():
|
||||
global _gen_kwargs
|
||||
return _gen_kwargs
|
||||
|
||||
|
||||
def get_scheduler(key):
|
||||
@@ -63,12 +75,15 @@ def get_scheduler(key):
|
||||
|
||||
def clear_cache():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _pipe_kwargs
|
||||
global _gen_kwargs
|
||||
global _schedulers
|
||||
del _sd_obj
|
||||
del _config_obj
|
||||
del _pipe_kwargs
|
||||
del _gen_kwargs
|
||||
del _schedulers
|
||||
gc.collect()
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_pipe_kwargs = None
|
||||
_gen_kwargs = None
|
||||
_schedulers = None
|
||||
|
||||
Reference in New Issue
Block a user