Added flags for metadata information. (#894)

This commit is contained in:
Evan Guan
2023-02-01 05:16:11 -08:00
committed by GitHub
parent 3eceeb7b23
commit 8cafe56eb4
7 changed files with 133 additions and 9 deletions

View File

@@ -2,10 +2,12 @@ import os
os.environ["AMD_ENABLE_LLPC"] = "1"
import json
import torch
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
@@ -61,7 +63,29 @@ def save_output_img(output_img):
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
)
output_img.save(
output_path / f"{out_img_name}.png", "PNG", pnginfo=pngInfo
)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": args.hf_model_id,
@@ -83,6 +107,11 @@ def save_output_img(output_img):
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
with open(f"{output_path}/{out_img_name}.json", "w") as f:
json.dump(new_entry, f, indent=4)
txt2img_obj = None
config_obj = None
@@ -106,6 +135,8 @@ def txt2img_inf(
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global txt2img_obj
global config_obj
@@ -119,6 +150,8 @@ def txt2img_inf(
args.scheduler = scheduler
args.hf_model_id = custom_model_id if custom_model_id else model_id
args.ckpt_loc = ckpt_file_obj.name if ckpt_file_obj else ""
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(

View File

@@ -270,6 +270,20 @@ p.add_argument(
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
##############################################################################
### Web UI flags
##############################################################################

View File

@@ -148,6 +148,17 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
step=0.1,
label="CFG Scale",
)
with gr.Row():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=False,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
@@ -211,6 +222,8 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
@@ -233,6 +246,8 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,