Compare commits

...

1 Commits

Author SHA1 Message Date
Ean Garvey
6dfded3759 Fix batch count 2024-05-30 14:32:31 -04:00
2 changed files with 30 additions and 25 deletions

View File

@@ -272,7 +272,7 @@ class StableDiffusion:
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("[LOG] Submitting Request...")
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
@@ -282,9 +282,8 @@ def shark_sd_fn_dict_input(
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
for i in range(1):
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
yield generated_imgs
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def shark_sd_fn(
@@ -412,22 +411,27 @@ def shark_sd_fn(
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
save_output_img(
out_imgs[current_batch],
seed,
sd_kwargs,
)
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return generated_imgs, ""
return (generated_imgs, "")
def cancel_sd():

View File

@@ -587,21 +587,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
@@ -718,6 +703,22 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
pull_kwargs = dict(
fn=pull_sd_configs,