import os
import re
import json
import torch
import numpy as np

from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode


resamplers = {
    "Lanczos": Image.Resampling.LANCZOS,
    "Nearest Neighbor": Image.Resampling.NEAREST,
    "Bilinear": Image.Resampling.BILINEAR,
    "Bicubic": Image.Resampling.BICUBIC,
    "Hamming": Image.Resampling.HAMMING,
    "Box": Image.Resampling.BOX,
}

resampler_list = resamplers.keys()

# Save output images and the inputs corresponding to them.
def save_output_img(output_img, img_seed, extra_info=None):
    from apps.amdshark_studio.web.utils.file_utils import (
        get_generated_imgs_path,
        get_generated_imgs_todays_subdir,
    )
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts

    if extra_info is None:
        extra_info = {}
    generated_imgs_path = Path(
        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
    )
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    # Build the output filename from the time, a sanitized prompt prefix and
    # the seed.
    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"

    img_model = extra_info["base_model_id"]
    if extra_info["custom_weights"] not in [None, "None"]:
        img_model = Path(os.path.basename(extra_info["custom_weights"])).stem

    img_vae = None
    if extra_info["custom_vae"]:
        img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
    img_loras = None
    if extra_info["embeddings"]:
        img_lora = []
        for i in extra_info["embeddings"]:
            # Collect the file-name stem of each embedding/LoRA.
            img_lora.append(Path(os.path.basename(i)).stem)
        img_loras = ", ".join(img_lora)

    if cmd_opts.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
        output_img.save(out_img_path, quality=95, subsampling=0)
    else:
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
        pngInfo = PngImagePlugin.PngInfo()

        if cmd_opts.write_metadata_to_png:
            # Using a conditional expression caused problems, so setting a new
            # variable for now.
            # if cmd_opts.use_hiresfix:
            #     png_size_text = (
            #         f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
            #     )
            # else:
            png_size_text = f"{extra_info['width']}x{extra_info['height']}"

            pngInfo.add_text(
                "parameters",
                f"{extra_info['prompt'][0]}"
                f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
                f"\nSteps: {extra_info['steps']}, "
                f"Sampler: {extra_info['scheduler']}, "
                f"CFG scale: {extra_info['guidance_scale']}, "
                f"Seed: {img_seed}, "
                f"Size: {png_size_text}, "
                f"Model: {img_model}, "
                f"VAE: {img_vae}, "
                f"LoRA: {img_loras}",
            )

        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)

        if cmd_opts.output_img_format not in ["png", "jpg"]:
            print(
                f"[ERROR] Format {cmd_opts.output_img_format} is not "
                "supported yet. Image saved as png instead. "
                "Supported formats: png / jpg"
            )

    # To be as low-impact as possible to the existing CSV format, we append
    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
    # importance for each data point. Something to consider.
    new_entry = {}
    new_entry.update(extra_info)

    # Append a row to the per-day CSV, writing the header only when the file
    # is first created. newline="" avoids blank rows on Windows.
    csv_mode = "a" if os.path.isfile(csv_path) else "w"
    with open(csv_path, csv_mode, encoding="utf-8", newline="") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if csv_mode == "w":
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)

    # Also write the full metadata as a JSON sidecar next to the image.
    json_path = Path(generated_imgs_path, f"{out_img_name}.json")
    with open(json_path, "w") as f:
        json.dump(new_entry, f, indent=4)
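
# Illustrative call of save_output_img (a sketch, not taken from the app's
# call sites): the keys below are the ones the function reads from
# `extra_info`; the values are made-up examples.
#
#     save_output_img(
#         output_img=pil_image,
#         img_seed=42,
#         extra_info={
#             "prompt": ["a watercolor painting of a lighthouse"],
#             "negative_prompt": [""],
#             "steps": 20,
#             "scheduler": "EulerDiscrete",
#             "guidance_scale": 7.5,
#             "width": 512,
#             "height": 512,
#             "base_model_id": "some/base-model-id",
#             "custom_weights": None,
#             "custom_vae": None,
#             "embeddings": [],
#         },
#     )
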
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms to the model constraints: both width and height must lie in the
# range [128, 768] and be multiples of 8. This utility resizes the input image
# to satisfy those constraints while maintaining its aspect ratio before
# sending it to the stencil pipeline. (A worked example follows the function.)
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
    aspect_ratio = width / height

    # If the smaller dimension is below 128, scale up so it becomes 128.
    min_size = min(width, height)
    if min_size < 128:
        n_size = 128
        if width == min_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)

    # If the larger dimension exceeds 768, scale down so it becomes 768.
    max_size = max(width, height)
    if max_size > 768:
        n_size = 768
        if width == max_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)

    # Round both dimensions down to the nearest multiple of 8.
    n_width = (width // 8) * 8
    n_height = (height // 8) * 8

    if resampler_type in resamplers:
        resampler = resamplers[resampler_type]
    else:
        resampler = resamplers["Nearest Neighbor"]
    new_image = image.resize((n_width, n_height), resample=resampler)
    return new_image, n_width, n_height
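
# Worked example for resize_stencil (illustrative): a 100x64 input has its
# smaller side raised to 128, giving 200x128; a 1920x1080 input has its larger
# side lowered to 768, giving 768x432. Both results keep the aspect ratio and
# are multiples of 8, as required above.
#
#     img_small = Image.new("RGB", (100, 64))
#     resized, w, h = resize_stencil(img_small, 100, 64)       # w, h == 200, 128
#
#     img_large = Image.new("RGB", (1920, 1080))
#     resized, w, h = resize_stencil(img_large, 1920, 1080)    # w, h == 768, 432
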
# Normalize an init-image input (a list, a file path, a PIL image, or a dict
# with an "image" entry) into a torch tensor, or list of tensors, of shape
# (1, 3, height, width) with values in [-1, 1]. Written as a method: `self`
# only needs to provide `width`, `height` and `dtype`.
def process_sd_init_image(self, sd_init_image, resample_type):
    if isinstance(sd_init_image, list):
        images = []
        for img in sd_init_image:
            img, _ = self.process_sd_init_image(img, resample_type)
            images.append(img)
        is_img2img = True
        return images, is_img2img
    if isinstance(sd_init_image, str):
        if os.path.isfile(sd_init_image):
            sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
            image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
            # The recursive call has already resized and converted the image,
            # so return its result directly.
            return image, is_img2img
        else:
            image = None
            is_img2img = False
    elif isinstance(sd_init_image, Image.Image):
        image = sd_init_image.convert("RGB")
    elif sd_init_image:
        image = sd_init_image["image"].convert("RGB")
    else:
        image = None
        is_img2img = False
    if image is not None:
        resample_type = (
            resamplers[resample_type]
            if resample_type in resampler_list
            # Fallback to Lanczos
            else Image.Resampling.LANCZOS
        )
        image = image.resize((self.width, self.height), resample=resample_type)
        # HWC uint8 -> NCHW float in [-1, 1].
        image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
        image_arr = image_arr / 255.0
        image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
        image_arr = 2 * (image_arr - 0.5)
        is_img2img = True
        image = image_arr
    return image, is_img2img
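
# Illustrative binding for process_sd_init_image (a sketch, not how the app
# wires it up): any object exposing `width`, `height` and `dtype` works, since
# that is all the helper reads from `self`.
#
#     class _DummyPipeline:
#         width, height, dtype = 512, 512, torch.float32
#         process_sd_init_image = process_sd_init_image
#
#     tensor, is_img2img = _DummyPipeline().process_sd_init_image(
#         Image.new("RGB", (640, 480)), "Lanczos"
#     )
#     # tensor.shape == (1, 3, 512, 512), values in [-1, 1], is_img2img is True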