diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index a0905587..c9e062e3 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -272,7 +272,7 @@ class StableDiffusion: def shark_sd_fn_dict_input( sd_kwargs: dict, ): - print("[LOG] Submitting Request...") + print("\n[LOG] Submitting Request...") for key in sd_kwargs: if sd_kwargs[key] in [None, []]: @@ -282,9 +282,8 @@ def shark_sd_fn_dict_input( if key == "seed": sd_kwargs[key] = int(sd_kwargs[key]) - for i in range(1): - generated_imgs = yield from shark_sd_fn(**sd_kwargs) - yield generated_imgs + generated_imgs = yield from shark_sd_fn(**sd_kwargs) + return generated_imgs def shark_sd_fn( @@ -412,22 +411,27 @@ def shark_sd_fn( for current_batch in range(batch_count): start_time = time.time() out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) + if not isinstance(out_imgs, list): + out_imgs = [out_imgs] # total_time = time.time() - start_time # text_output = f"Total image(s) generation time: {total_time:.4f}sec" # print(f"\n[LOG] {text_output}") # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: - save_output_img( - out_imgs[current_batch], - seed, - sd_kwargs, - ) + for batch in range(batch_size): + save_output_img( + out_imgs[batch], + seed, + sd_kwargs, + ) generated_imgs.extend(out_imgs) + # TODO: make seed changes over batch counts more configurable. + submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1 yield generated_imgs, status_label( "Stable Diffusion", current_batch + 1, batch_count, batch_size ) - return generated_imgs, "" + return (generated_imgs, "") def cancel_sd(): diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index ee8bf77f..fc018dbb 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -587,21 +587,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element: object_fit="fit", preview=True, ) - with gr.Row(): - std_output = gr.Textbox( - value=f"{sd_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=2, - elem_id="std_output", - show_label=True, - label="Log", - show_copy_button=True, - ) - sd_element.load( - logger.read_sd_logs, None, std_output, every=1 - ) - sd_status = gr.Textbox(visible=False) with gr.Row(): batch_count = gr.Slider( 1, @@ -718,6 +703,22 @@ with gr.Blocks(title="Stable Diffusion") as sd_element: inputs=[sd_json, sd_config_name], outputs=[sd_config_name], ) + with gr.Tab(label="Log", id=103) as sd_tab_log: + with gr.Row(): + std_output = gr.Textbox( + value=f"{sd_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=2, + elem_id="std_output", + show_label=True, + label="Log", + show_copy_button=True, + ) + sd_element.load( + logger.read_sd_logs, None, std_output, every=1 + ) + sd_status = gr.Textbox(visible=False) pull_kwargs = dict( fn=pull_sd_configs,