mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
SD - Add repeatable (batch) seeds option (#1654)
* Generates the seeds for all batch_count batches being run up front rather than generating the seed for a batch just before it is run. * Adds a --repeatable_seeds argument defaulting to False * When repeatable_seeds=True, the first seed for a set of batches will also be used as the rng seed for the subsequent batch seeds in the run. The rng seed is then reset. * When repeatable_seeds=False, batch seeding works as currently. * Update scripts under apps/scripts that support the batch_count argument to also support the repeatable_seeds argument. * UI/Web: Adds a checkbox element on each SD tab after batch count/size for toggling repeatable seeds, and update _inf functions to take this into account. * UI/Web: Moves the Stop buttons out of the Advanced sections and next to Generate to make things not fit quite so badly with the extra UI elements. * UI/Web: Fixes logging to the upscaler output text box not working correctly when running multiple batches.
This commit is contained in:
@@ -58,11 +58,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = inpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -76,7 +73,7 @@ def main():
|
||||
args.inpaint_full_res_padding,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
@@ -90,7 +87,10 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -51,11 +51,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = outpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -74,7 +71,7 @@ def main():
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
@@ -88,7 +85,10 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -42,11 +42,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -56,7 +53,7 @@ def main():
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
@@ -70,7 +67,12 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
batch_seeds,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
|
||||
@@ -320,10 +320,18 @@ p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of batch to be generated with random seeds in "
|
||||
help="Number of batches to be generated with random seeds in "
|
||||
"single execution.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--repeatable_seeds",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="The seed of the first batch will be used as the rng seed to "
|
||||
"generate the subsequent seeds for subsequent batches in that run.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_loc",
|
||||
type=str,
|
||||
@@ -524,6 +532,8 @@ p.add_argument(
|
||||
help="If import_mlir is True, saves mlir via the debug option "
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
@@ -8,7 +8,12 @@ from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
getstate as random_getstate,
|
||||
setstate as random_setstate,
|
||||
)
|
||||
import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
@@ -730,6 +735,28 @@ def sanitize_seed(seed):
|
||||
return seed
|
||||
|
||||
|
||||
# Generate a set of seeds, using as the first seed of the set,
|
||||
# optionally using it as the rng seed for subsequent seeds in the set
|
||||
def batch_seeds(seed, batch_count, repeatable=False):
|
||||
# use the passed seed as the initial seed of the batch
|
||||
seeds = [sanitize_seed(seed)]
|
||||
|
||||
if repeatable:
|
||||
# use the initial seed as the rng generator seed
|
||||
saved_random_state = random_getstate()
|
||||
seed_random(seed)
|
||||
|
||||
# generate the additional seeds
|
||||
for i in range(1, batch_count):
|
||||
seeds.append(sanitize_seed(-1))
|
||||
|
||||
if repeatable:
|
||||
# reset the rng back to normal
|
||||
random_setstate(saved_random_state)
|
||||
|
||||
return seeds
|
||||
|
||||
|
||||
# clear all the cached objects to recompile cleanly.
|
||||
def clear_all():
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
|
||||
@@ -66,6 +66,7 @@ def img2img_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -230,12 +231,11 @@ def img2img_inf(
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -246,7 +246,7 @@ def img2img_inf(
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
@@ -254,9 +254,10 @@ def img2img_inf(
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
@@ -265,7 +266,7 @@ def img2img_inf(
|
||||
else:
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
@@ -344,6 +345,7 @@ def img2img_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
@@ -565,16 +567,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
@@ -598,6 +602,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -607,7 +616,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -619,16 +627,15 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -683,6 +690,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -64,6 +64,7 @@ def inpaint_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: int,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -180,14 +181,12 @@ def inpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -200,26 +199,27 @@ def inpaint_inf(
|
||||
inpaint_full_res_padding,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Inpaint", i + 1, batch_count, batch_size
|
||||
"Inpaint", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
@@ -292,6 +292,7 @@ def inpaint_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
@@ -498,6 +499,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -507,7 +513,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -519,16 +524,15 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -584,6 +588,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output, inpaint_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -64,6 +64,7 @@ def outpaint_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -177,8 +178,7 @@ def outpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
@@ -186,9 +186,7 @@ def outpaint_inf(
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -206,26 +204,27 @@ def outpaint_inf(
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Outpaint", i + 1, batch_count, batch_size
|
||||
"Outpaint", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
@@ -300,6 +299,7 @@ def outpaint_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
@@ -526,6 +526,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -535,7 +541,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -547,16 +552,15 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -612,6 +616,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output, outpaint_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -61,6 +61,7 @@ def txt2img_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -177,12 +178,10 @@ def txt2img_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -191,26 +190,27 @@ def txt2img_inf(
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image", i + 1, batch_count, batch_size
|
||||
"Text-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
@@ -267,6 +267,7 @@ def txt2img_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
@@ -439,16 +440,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
@@ -473,7 +476,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -485,17 +491,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -555,6 +559,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -57,6 +57,7 @@ def upscaler_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -176,12 +177,10 @@ def upscaler_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
@@ -198,7 +197,7 @@ def upscaler_inf(
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
@@ -213,39 +212,40 @@ def upscaler_inf(
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, "
|
||||
f"noise_level={noise_level}, "
|
||||
f"guidance_scale={guidance_scale}, "
|
||||
f"seed={seeds[:current_batch + 1]}"
|
||||
)
|
||||
text_output += (
|
||||
f"\ninput size={height}x{width}, "
|
||||
f"output size={height*4}x{width*4}, "
|
||||
f"batch_count={batch_count}, "
|
||||
f"batch_size={batch_size}, "
|
||||
f"max_length={args.max_length}\n"
|
||||
)
|
||||
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
save_output_img(high_res_img, seeds[current_batch], extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Upscaler", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, "
|
||||
f"noise_level={noise_level}, "
|
||||
f"guidance_scale={guidance_scale}, "
|
||||
f"seed={seeds}"
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={height}x{width}, "
|
||||
f"batch_count={batch_count}, "
|
||||
f"batch_size={batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
)
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
@@ -314,6 +314,7 @@ def upscaler_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
@@ -518,6 +519,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -527,7 +533,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -539,16 +544,15 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -601,6 +605,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output, upscaler_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
Reference in New Issue
Block a user