[SD] Add batch count in stable diffusion

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-02-07 21:46:34 +05:30
parent eeb20b531a
commit 5af124c5a5
2 changed files with 77 additions and 75 deletions

View File

@@ -138,6 +138,7 @@ def txt2img_inf(
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
@@ -155,7 +156,6 @@ def txt2img_inf(
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
img_seed = utils.sanitize_seed(seed)
args.steps = steps
args.scheduler = scheduler
@@ -232,30 +232,38 @@ def txt2img_inf(
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
total_time = time.time() - start_time
save_output_img(generated_imgs[0], img_seed)
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={img_seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"

View File

@@ -119,9 +119,17 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
"SharkEulerDiscrete",
],
)
batch_size = gr.Slider(
1, 4, value=1, step=1, label="Number of Images"
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
@@ -159,14 +167,20 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
label="CFG Scale",
)
with gr.Row():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
batch_count = gr.Slider(
1,
10,
value=1,
step=1,
label="Batch Count",
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Row():
@@ -213,53 +227,33 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
value=output_dir,
interactive=False,
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
stable_diffusion.click(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
shark_web.queue()
shark_web.launch(