diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 006f74ec..78b51cb0 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -168,7 +168,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): sd_image, sd_json["height"], sd_json["width"], - sd_json["steps"], + gr.update(), sd_json["strength"], sd_json["guidance_scale"], sd_json["seed"], @@ -238,11 +238,42 @@ def base_model_changed(base_model_id): new_choices = get_checkpoints( os.path.join("checkpoints", os.path.basename(str(base_model_id))) ) + get_checkpoints(model_type="checkpoints") + if "turbo" in base_model_id: + new_steps = gr.Dropdown( + value=cmd_opts.steps, + choices=[1, 2, 3, 4], + label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=False, + ) + if "stable-diffusion-xl-base-1.0" in base_model_id: + new_steps = gr.Dropdown( + value=40, + choices=[20, 25, 30, 35, 40, 45, 50], + label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=False, + ) + elif ".py" in base_model_id: + new_steps = gr.Dropdown( + value=20, + choices=[10, 15, 20, 28], + label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=True, + ) + else: + new_steps = gr.Dropdown( + value=cmd_opts.steps, + choices=[10, 20, 30, 40, 50], + label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=True, + ) - return gr.Dropdown( - value=new_choices[0] if len(new_choices) > 0 else "None", - choices=["None"] + new_choices, - ) + return [ + gr.Dropdown( + value=new_choices[0] if len(new_choices) > 0 else "None", + choices=["None"] + new_choices, + ), + new_steps + ] with gr.Blocks(title="Stable Diffusion") as sd_element: @@ -319,11 +350,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element: choices=["None"] + get_checkpoints(os.path.basename(str(base_model_id))), ) # custom_weights - base_model_id.change( - fn=base_model_changed, - inputs=[base_model_id], - outputs=[custom_weights], - ) sd_vae_info = (str(get_checkpoints_path("vae"))).replace( "\\", "\n\\" ) @@ -423,12 +449,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element: allow_custom_value=False, ) with gr.Row(): - steps = gr.Slider( - 1, - 50, + steps = gr.Dropdown( value=cmd_opts.steps, - step=1, + choices=[1, 2, 3, 4], label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=True, ) guidance_scale = gr.Slider( 0, @@ -728,6 +753,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element: logger.read_sd_logs, None, std_output, every=1 ) sd_status = gr.Textbox(visible=False) + base_model_id.change( + fn=base_model_changed, + inputs=[base_model_id], + outputs=[custom_weights, steps], + ) pull_kwargs = dict( fn=pull_sd_configs,