diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 01f2f7af..ee3da655 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -190,9 +190,6 @@ class SharkifyStableDiffusionModel: ) self.model_id = model_id if custom_weights == "" else custom_weights - # TODO: remove the following line when stable-diffusion-2-1 works - if self.model_id == "stabilityai/stable-diffusion-2-1": - self.model_id = "stabilityai/stable-diffusion-2-1-base" self.custom_vae = custom_vae self.precision = precision self.base_vae = use_base_vae @@ -208,6 +205,7 @@ class SharkifyStableDiffusionModel: + "_" + precision ) + self.model_namedata = self.model_name print(f"use_tuned? sharkify: {use_tuned}") self.use_tuned = use_tuned if use_tuned: @@ -272,7 +270,11 @@ class SharkifyStableDiffusionModel: stencil_names = [] for i, stencil in enumerate(self.stencils): if stencil is not None: - cnet_config = model_config + stencil.split("_")[-1] + cnet_config = ( + self.model_namedata + + "_v1-5" + + stencil.split("_")[-1] + ) stencil_names.append( get_extended_name(sub_model + cnet_config) ) @@ -539,7 +541,7 @@ class SharkifyStableDiffusionModel: ) if use_lora != "": update_lora_weight(self.unet, use_lora, "unet") - self.in_channels = self.unet.in_channels + self.in_channels = self.unet.config.in_channels self.train(False) def forward( @@ -765,10 +767,11 @@ class SharkifyStableDiffusionModel: model_name = "stencil_adapter_512" if use_large else "stencil_adapter" ext_model_name = self.model_name[model_name] if isinstance(ext_model_name, list): + desired_name = None + print(ext_model_name) for i in ext_model_name: if stencil_id.split("_")[-1] in i: desired_name = i - print(f"Multi-CN: compiling model {i}") else: continue if desired_name: diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py index ba1d52fd..9ff4c123 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py @@ -162,6 +162,7 @@ class Image2ImagePipeline(StableDiffusionPipeline): images, resample_type, control_mode, + preprocessed_hints=[], ): # prompts and negative prompts must be a list. if isinstance(prompts, str): diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py index 386feeae..3b934723 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -25,7 +25,10 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( StableDiffusionPipeline, ) -from apps.stable_diffusion.src.utils import controlnet_hint_conversion +from apps.stable_diffusion.src.utils import ( + controlnet_hint_conversion, + controlnet_hint_reshaping, +) from apps.stable_diffusion.src.utils import ( start_profiling, end_profiling, @@ -213,7 +216,7 @@ class StencilPipeline(StableDiffusionPipeline): ) for i, controlnet_hint in enumerate(stencil_hints): if controlnet_hint is None: - continue + pass if text_embeddings.shape[1] <= self.model_max_length: control = self.controlnet[i]( "forward", @@ -300,7 +303,6 @@ class StencilPipeline(StableDiffusionPipeline): send_to_host=False, ) else: - print(self.unet_512) noise_pred = self.unet_512( "forward", ( @@ -389,13 +391,31 @@ class StencilPipeline(StableDiffusionPipeline): stencil_images, resample_type, control_mode, + preprocessed_hints, ): # Control Embedding check & conversion # controlnet_hint = controlnet_hint_conversion( # image, use_stencil, height, width, dtype, num_images_per_prompt=1 # ) stencil_hints = [] + for i, hint in enumerate(preprocessed_hints): + if hint is not None: + hint = controlnet_hint_reshaping( + hint, + height, + width, + dtype, + num_images_per_prompt=1, + ) + stencil_hints.append(hint) + for i, stencil in enumerate(stencils): + if stencil == None: + continue + if len(stencil_hints) >= i: + if stencil_hints[i] is not None: + print(f"Using preprocessed controlnet hint for {stencil}") + continue image = stencil_images[i] stencil_hints.append( controlnet_hint_conversion( diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py index 7f067c69..8436e655 100644 --- a/apps/stable_diffusion/src/utils/__init__.py +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation from apps.stable_diffusion.src.utils.stable_args import args from apps.stable_diffusion.src.utils.stencils.stencil_utils import ( controlnet_hint_conversion, + controlnet_hint_reshaping, get_stencil_model_id, ) from apps.stable_diffusion.src.utils.utils import ( diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py index ba6d5b06..8a7466db 100644 --- a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py +++ b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py @@ -20,9 +20,7 @@ def save_img(img): get_generated_imgs_todays_subdir, ) - subdir = Path( - get_generated_imgs_path(), get_generated_imgs_todays_subdir() - ) + subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints") os.makedirs(subdir, exist_ok=True) if isinstance(img, Image.Image): img.save( @@ -60,7 +58,7 @@ def HWC3(x): return y -def controlnet_hint_shaping( +def controlnet_hint_reshaping( controlnet_hint, height, width, dtype, num_images_per_prompt=1 ): channels = 3 @@ -79,7 +77,7 @@ def controlnet_hint_shaping( ) return controlnet_hint else: - return controlnet_hint_shaping( + return controlnet_hint_reshaping( Image.fromarray(controlnet_hint.detach().numpy()), height, width, @@ -111,7 +109,7 @@ def controlnet_hint_shaping( ) # b h w c -> b c h w return controlnet_hint else: - return controlnet_hint_shaping( + return controlnet_hint_reshaping( Image.fromarray(controlnet_hint), height, width, @@ -128,14 +126,16 @@ def controlnet_hint_shaping( np.float16 ) # to numpy controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR - return + return controlnet_hint_reshaping( + controlnet_hint, height, width, dtype, num_images_per_prompt + ) else: (hint_w, hint_h) = controlnet_hint.size left = int((hint_w - width) / 2) right = left + height controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h)) controlnet_hint = controlnet_hint.resize((width, height)) - return controlnet_hint_shaping( + return controlnet_hint_reshaping( controlnet_hint, height, width, dtype, num_images_per_prompt ) else: @@ -169,7 +169,7 @@ def controlnet_hint_conversion( controlnet_hint = hint_zoedepth(image) case _: return None - controlnet_hint = controlnet_hint_shaping( + controlnet_hint = controlnet_hint_reshaping( controlnet_hint, height, width, dtype, num_images_per_prompt ) return controlnet_hint diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 95dce3eb..b3b6aae8 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -1008,8 +1008,7 @@ def get_generation_text_info(seeds, device): # Both width and height should be in the range of [128, 768] and multiple of 8. # This utility function performs the transformation on the input image while # also maintaining the aspect ratio before sending it to the stencil pipeline. -def resize_stencil(image: Image.Image): - width, height = image.size +def resize_stencil(image: Image.Image, width, height): aspect_ratio = width / height min_size = min(width, height) if min_size < 128: diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 8b35afcb..11ca040f 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -81,6 +81,7 @@ def img2img_inf( control_mode: str, stencils: list, images: list, + preprocessed_hints: list, ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, @@ -151,7 +152,7 @@ def img2img_inf( stencil_count += 1 if stencil_count > 0: args.hf_model_id = "runwayml/stable-diffusion-v1-5" - image, width, height = resize_stencil(image) + image, _, _ = resize_stencil(image, width, height) elif "Shark" in args.scheduler: print( f"Shark schedulers are not supported. Switching to EulerDiscrete " @@ -195,7 +196,7 @@ def img2img_inf( model_id = ( args.hf_model_id if args.hf_model_id - else "stabilityai/stable-diffusion-1-5-base" + else "runwayml/stable-diffusion-v1-5" ) global_obj.set_schedulers(get_schedulers(model_id)) scheduler_obj = global_obj.get_scheduler(args.scheduler) @@ -278,6 +279,7 @@ def img2img_inf( images, resample_type=resample_type, control_mode=control_mode, + preprocessed_hints=preprocessed_hints, ) total_time = time.time() - start_time text_output = get_generation_text_info( @@ -308,6 +310,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: STENCIL_COUNT = 2 stencils = gr.State([None] * STENCIL_COUNT) images = gr.State([None] * STENCIL_COUNT) + preprocessed_hints = gr.State([None] * STENCIL_COUNT) with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(): @@ -374,7 +377,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: img2img_init_image = gr.Image( label="Input Image", type="pil", - height=512, + height=300, interactive=True, ) @@ -388,10 +391,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: ] def cnet_preview( - model, input_image, index, stencils, images + model, + input_image, + index, + stencils, + images, + preprocessed_hints, ): + if isinstance(input_image, PIL.Image.Image): + img_dict = { + "background": None, + "layers": [None], + "composite": input_image, + } + input_image = EditorValue(img_dict) images[index] = input_image - stencils[index] = model + if model: + stencils[index] = model match model: case "canny": canny = CannyDetector() @@ -400,41 +416,75 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: 100, 200, ) + preprocessed_hints[index] = Image.fromarray( + result + ) return ( Image.fromarray(result), stencils, images, + preprocessed_hints, ) case "openpose": openpose = OpenposeDetector() result = openpose( np.array(input_image["composite"]) ) - print(result) - # TODO: This is just an empty canvas, need to draw the candidates (which are in result[1]) + preprocessed_hints[index] = Image.fromarray( + result[0] + ) return ( Image.fromarray(result[0]), stencils, images, + preprocessed_hints, ) case "zoedepth": zoedepth = ZoeDetector() result = zoedepth( np.array(input_image["composite"]) ) + preprocessed_hints[index] = Image.fromarray( + result + ) return ( Image.fromarray(result), stencils, images, + preprocessed_hints, ) case "scribble": + preprocessed_hints[index] = input_image[ + "composite" + ] return ( input_image["composite"], stencils, images, + preprocessed_hints, ) case _: - return (None, stencils, images) + preprocessed_hints[index] = None + return ( + None, + stencils, + images, + preprocessed_hints, + ) + + def import_original(original_img, width, height): + resized_img, _, _ = resize_stencil( + original_img, width, height + ) + img_dict = { + "background": resized_img, + "layers": [resized_img], + "composite": None, + } + return gr.ImageEditor( + value=EditorValue(img_dict), + crop_size=(width, height), + ) def create_canvas(width, height): data = Image.fromarray( @@ -451,8 +501,31 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: } return EditorValue(img_dict) - def update_cn_input(model, width, height): - if model == "scribble": + def update_cn_input( + model, + width, + height, + stencils, + images, + preprocessed_hints, + index, + ): + if model == None: + stencils[index] = None + images[index] = None + preprocessed_hints[index] = None + return [ + gr.ImageEditor(value=None, visible=False), + gr.Image(value=None), + gr.Slider(visible=False), + gr.Slider(visible=False), + gr.Button(visible=False), + gr.Button(visible=False), + stencils, + images, + preprocessed_hints, + ] + elif model == "scribble": return [ gr.ImageEditor( visible=True, @@ -460,20 +533,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: show_label=False, image_mode="RGB", type="pil", - value=create_canvas(width, height), brush=Brush( - colors=["#000000"], color_mode="fixed" + colors=["#000000"], + color_mode="fixed", + default_size=2, ), ), gr.Image( visible=True, show_label=False, - interactive=False, + interactive=True, show_download_button=False, ), - gr.Slider(visible=True), - gr.Slider(visible=True), + gr.Slider(visible=True, label="Canvas Width"), + gr.Slider(visible=True, label="Canvas Height"), gr.Button(visible=True), + gr.Button(visible=False), + stencils, + images, + preprocessed_hints, ] else: return [ @@ -482,7 +560,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: image_mode="RGB", type="pil", interactive=True, - value=None, ), gr.Image( visible=True, @@ -490,9 +567,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: interactive=True, show_download_button=False, ), - gr.Slider(visible=False), - gr.Slider(visible=False), + gr.Slider(visible=True, label="Input Width"), + gr.Slider(visible=True, label="Input Height"), gr.Button(visible=False), + gr.Button(visible=True), + stencils, + images, + preprocessed_hints, ] with gr.Row(): @@ -525,51 +606,85 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: value="Make Canvas!", visible=False, ) + use_input_img_1 = gr.Button( + value="Use Original Image", + visible=False, + ) cnet_1_image = gr.ImageEditor( visible=False, image_mode="RGB", interactive=True, - show_label=False, + show_label=True, + label="Input Image", type="pil", ) cnet_1_output = gr.Image( - visible=True, show_label=False + value=None, + visible=True, + label="Preprocessed Hint", + interactive=True, + ) + + use_input_img_1.click( + import_original, + [img2img_init_image, canvas_width, canvas_height], + [cnet_1_image], ) cnet_1_model.input( - update_cn_input, - [cnet_1_model, canvas_width, canvas_height], - [ + fn=( + lambda m, w, h, s, i, p: update_cn_input( + m, w, h, s, i, p, 0 + ) + ), + inputs=[ + cnet_1_model, + canvas_width, + canvas_height, + stencils, + images, + preprocessed_hints, + ], + outputs=[ cnet_1_image, cnet_1_output, canvas_width, canvas_height, make_canvas, + use_input_img_1, + stencils, + images, + preprocessed_hints, ], ) make_canvas.click( - update_cn_input, - [cnet_1_model, canvas_width, canvas_height], + create_canvas, + [canvas_width, canvas_height], [ cnet_1_image, - cnet_1_output, - canvas_width, - canvas_height, - make_canvas, ], ) - cnet_1.click( + gr.on( + triggers=[cnet_1.click], fn=( - lambda a, b, s, i: cnet_preview(a, b, 0, s, i) + lambda a, b, s, i, p: cnet_preview( + a, b, 0, s, i, p + ) ), inputs=[ cnet_1_model, cnet_1_image, stencils, images, + preprocessed_hints, + ], + outputs=[ + cnet_1_output, + stencils, + images, + preprocessed_hints, ], - outputs=[cnet_1_output, stencils, images], ) with gr.Row(): with gr.Column(): @@ -601,49 +716,81 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: value="Make Canvas!", visible=False, ) + use_input_img_2 = gr.Button( + value="Use Original Image", + visible=False, + ) cnet_2_image = gr.ImageEditor( visible=False, image_mode="RGB", interactive=True, - show_label=False, type="pil", + show_label=True, + label="Input Image", + ) + use_input_img_2.click( + import_original, + [img2img_init_image, canvas_width, canvas_height], + [cnet_2_image], ) cnet_2_output = gr.Image( - visible=True, show_label=False + value=None, + visible=True, + label="Preprocessed Hint", + interactive=True, ) cnet_2_model.select( - update_cn_input, - [cnet_2_model, canvas_width, canvas_height], - [ + fn=( + lambda m, w, h, s, i, p: update_cn_input( + m, w, h, s, i, p, 0 + ) + ), + inputs=[ + cnet_2_model, + canvas_width, + canvas_height, + stencils, + images, + preprocessed_hints, + ], + outputs=[ cnet_2_image, cnet_2_output, canvas_width, canvas_height, make_canvas, + use_input_img_2, + stencils, + images, + preprocessed_hints, ], ) make_canvas.click( - update_cn_input, - [cnet_2_model, canvas_width, canvas_height], + create_canvas, + [canvas_width, canvas_height], [ cnet_2_image, - cnet_2_output, - canvas_width, - canvas_height, - make_canvas, ], ) cnet_2.click( fn=( - lambda a, b, s, i: cnet_preview(a, b, 1, s, i) + lambda a, b, s, i, p: cnet_preview( + a, b, 1, s, i, p + ) ), inputs=[ cnet_2_model, cnet_2_image, stencils, images, + preprocessed_hints, + ], + outputs=[ + cnet_2_output, + stencils, + images, + preprocessed_hints, ], - outputs=[cnet_2_output, stencils, images], ) control_mode = gr.Radio( choices=["Prompt", "Balanced", "Controlnet"], @@ -865,6 +1012,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: control_mode, stencils, images, + preprocessed_hints, ], outputs=[ img2img_gallery, diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 12614c20..85dae66d 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -348,14 +348,14 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web: scheduler = gr.Dropdown( elem_id="scheduler", label="Scheduler", - value=args.scheduler, + value="EulerDiscrete", choices=[ "DDIM", "EulerAncestralDiscrete", "EulerDiscrete", "LCMScheduler", ], - allow_custom_value=False, + allow_custom_value=True, visible=True, ) with gr.Column(): diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index aecf9366..9252ecee 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -307,10 +307,7 @@ def get_config_from_json(model_ckpt_or_id, jsonconfig): def default_config_exists(model_ckpt_or_id): - if model_ckpt_or_id in [ - "stabilityai/sdxl-turbo", - "stabilityai/stable_diffusion-xl-base-1.0", - ]: + if model_ckpt_or_id in default_configs.keys(): return model_ckpt_or_id elif "turbo" in model_ckpt_or_id.lower(): return "stabilityai/sdxl-turbo" @@ -326,7 +323,7 @@ default_configs = { value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette", ), gr.Slider(0, 10, value=2), - gr.Dropdown(value="EulerAncestralDiscrete"), + "EulerAncestralDiscrete", gr.Slider(0, value=0), 512, 512,