Include VAE & LoRA data into PNG metadata (#1573)

* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
This commit is contained in:
Nelson Sharpe
2023-06-22 19:05:54 -04:00
committed by GitHub
parent 8822b9acd7
commit 44a8f2f8db
4 changed files with 124 additions and 31 deletions

2
.gitignore vendored
View File

@@ -159,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode

View File

@@ -757,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -767,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed},"
f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
@@ -778,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
"Image saved as png instead. Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
@@ -791,6 +804,8 @@ def save_output_img(output_img, img_seed, extra_info={}):
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)

View File

@@ -553,6 +553,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_png_info_img,
@@ -566,5 +569,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)

View File

@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
return res
def try_find_model_base_from_png_metadata(
    file: str, folder: str = "models"
) -> str:
    """Resolve *file* to a locally present checkpoint inside *folder*.

    Strips a trailing ``.safetensors``/``.ckpt`` extension if present, then
    probes *folder* for a file of either kind with that stem.  Returns the
    matched file name (extension included) or ``""`` when nothing matches.
    When both variants exist, ``.safetensors`` wins (it is probed last).
    """
    # Drop a known weights extension so the bare stem can be re-probed.
    base = Path(file).stem if file.endswith((".safetensors", ".ckpt")) else file
    matched = ""
    # Probe .ckpt first, then .safetensors; a later hit overwrites.
    for ext in (".ckpt", ".safetensors"):
        candidate = base + ext
        if Path(get_custom_model_pathfile(candidate, folder)).is_file():
            matched = candidate
    return matched
def find_model_from_png_metadata(
    key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
    """Look up *key* in PNG *metadata* and resolve it to a model source.

    Returns ``(custom, hf_id)``: on success exactly one of the pair is
    non-empty — ``custom`` names a local/predefined model, ``hf_id`` a
    HuggingFace ``vendor/name`` id.  Both are ``""`` when the key is absent
    or nothing matches (a diagnostic is printed in the no-match case).
    """
    if key not in metadata:
        return "", ""
    model_file = metadata[key]
    hf_id = ""
    # Prefer a local checkpoint/safetensors file match.
    custom = try_find_model_base_from_png_metadata(model_file)
    # A predefined model name (ex: "Linaqruf/anything-v3.0") also counts as
    # "custom" and takes precedence over the local-file probe.
    if model_file in predefined_models:
        custom = model_file
    # Fall back to treating a "vendor/name" string as a HuggingFace id.
    if not custom and model_file.count("/"):
        hf_id = model_file
    if not (custom or hf_id):
        print(
            "Import PNG info: Unable to find a matching model for %s"
            % model_file
        )
    return custom, hf_id
def find_vae_from_png_metadata(
    key: str, metadata: dict[str, str | int]
) -> str:
    """Resolve an optional VAE entry from PNG *metadata*.

    Returns the matching local VAE file name, or ``""`` when the key is
    missing or no local file matches.  A VAE is optional, so a miss is
    silent by design — no diagnostic is printed and no error is raised.
    """
    if key not in metadata:
        return ""
    return try_find_model_base_from_png_metadata(metadata[key], "vae")
def find_lora_from_png_metadata(
    key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
    """Resolve an optional LoRA entry from PNG *metadata*.

    Returns ``(custom, hf_id)``: ``custom`` is a matching local LoRA file;
    otherwise ``hf_id`` carries a ``vendor/name`` HuggingFace id when the
    value contains a slash.  Both are ``""`` when the key is absent or
    nothing matches.  LoRA is optional, so a miss is silent by design.
    """
    if key not in metadata:
        return "", ""
    lora_file = metadata[key]
    custom = try_find_model_base_from_png_metadata(lora_file, "lora")
    hf_id = lora_file if (not custom and lora_file.count("/")) else ""
    return custom, hf_id
def import_png_metadata(
pil_data,
prompt,
@@ -74,40 +150,21 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
@@ -115,12 +172,24 @@ def import_png_metadata(
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
@@ -149,4 +218,7 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
)