Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-14 16:28:01 -05:00)
* Generates the seeds for all batch_count batches up front, rather than generating each batch's seed just before it is run.
* Adds a --repeatable_seeds argument, defaulting to False.
* When repeatable_seeds=True, the first seed for a set of batches is also used as the RNG seed for generating the subsequent batch seeds in the run; the RNG seed is then reset.
* When repeatable_seeds=False, batch seeding works as it does currently.
* Updates the scripts under apps/scripts that support the batch_count argument to also support the repeatable_seeds argument.
* UI/Web: Adds a checkbox on each SD tab, after batch count/size, for toggling repeatable seeds, and updates the _inf functions to take it into account.
* UI/Web: Moves the Stop buttons out of the Advanced sections and next to Generate, so the layout copes better with the extra UI elements.
* UI/Web: Fixes logging to the upscaler output text box, which did not work correctly when running multiple batches.
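For illustration, the seeding contract described above can be sketched in a few lines. This is not the repository's utils.batch_seeds implementation that the script below calls; the helper name batch_seeds_sketch, the use of Python's random module, the 32-bit seed range, and the seed == -1 "pick a random seed" convention are all assumptions made for the example.

import random


def batch_seeds_sketch(seed, batch_count, repeatable_seeds):
    """Illustrative sketch only: one seed per batch, generated up front.

    With repeatable_seeds=True the first seed also seeds the RNG that
    produces the remaining batch seeds, so the whole run can be reproduced
    from a single seed value; the RNG is then reset. With
    repeatable_seeds=False every batch simply gets an independent random seed.
    """
    uint32_max = 2**32 - 1
    # Assumed convention: seed == -1 means "pick a random seed".
    first = seed if seed != -1 else random.randint(0, uint32_max)
    if repeatable_seeds:
        saved_state = random.getstate()
        random.seed(first)
        rest = [random.randint(0, uint32_max) for _ in range(batch_count - 1)]
        random.setstate(saved_state)  # reset the RNG afterwards
    else:
        rest = [random.randint(0, uint32_max) for _ in range(batch_count - 1)]
    return [first] + rest


# With repeatable seeds, the same base seed reproduces the whole sequence.
assert batch_seeds_sketch(42, 4, True) == batch_seeds_sketch(42, 4, True)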
121 lines
3.5 KiB
Python
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
    args,
    OutpaintPipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)


def main():
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    model_id = (
        args.hf_model_id
        if "inpaint" in args.hf_model_id
        else "stabilityai/stable-diffusion-2-inpainting"
    )
    schedulers = get_schedulers(model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)

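    # Build the outpainting pipeline from the selected scheduler, model,
    # and generation options passed on the command line.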
    outpaint_obj = OutpaintPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )

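    # Generate one seed per batch up front; with --repeatable_seeds the whole
    # sequence is derived deterministically from the first seed.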
    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        start_time = time.time()
        generated_imgs = outpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            args.pixels,
            args.mask_blur,
            args.left,
            args.right,
            args.top,
            args.bottom,
            args.noise_q,
            args.color_variation,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seeds[current_batch],
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time
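        # Build a human-readable summary of this batch's settings for the
        # console output.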
text_output = f"prompt={args.prompts}"
|
|
text_output += f"\nnegative prompt={args.negative_prompts}"
|
|
text_output += (
|
|
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
|
)
|
|
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
|
text_output += (
|
|
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
|
)
|
|
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
|
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += outpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # save this information as metadata of output generated image.
        directions = []
        if args.left:
            directions.append("left")
        if args.right:
            directions.append("right")
        if args.top:
            directions.append("up")
        if args.bottom:
            directions.append("down")
        extra_info = {
            "PIXELS": args.pixels,
            "MASK_BLUR": args.mask_blur,
            "DIRECTIONS": directions,
            "NOISE_Q": args.noise_q,
            "COLOR_VARIATION": args.color_variation,
        }
        save_output_img(generated_imgs[0], seeds[current_batch], extra_info)
        print(text_output)


if __name__ == "__main__":
    main()