From ce9ce3a7c859002b0f07d76880d55eea3967e7b6 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 5 Dec 2023 03:29:18 -0600 Subject: [PATCH] (SD) Fix schedulers and multi-controlnet. (#2006) * (SD) Fixes schedulers if receiving noise preds as numpy arrays * Fix schedulers and stencil name * Multicontrolnet fixes --- .../src/models/model_wrappers.py | 66 +++++++++++------- ...pipeline_shark_stable_diffusion_stencil.py | 1 - .../pipeline_shark_stable_diffusion_utils.py | 2 +- .../src/schedulers/shark_eulerdiscrete.py | 19 ++++-- .../src/utils/resources/base_model.json | 2 +- .../src/utils/stencils/stencil_utils.py | 67 ++++++++++++------- apps/stable_diffusion/web/ui/img2img_ui.py | 37 +++++++--- 7 files changed, 127 insertions(+), 67 deletions(-) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 863b6434..01f2f7af 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -254,8 +254,8 @@ class SharkifyStableDiffusionModel: "stencil_unet_512", "vae", "vae_encode", - "stencil_adaptor", - "stencil_adaptor_512", + "stencil_adapter", + "stencil_adapter_512", ] index = 0 for model in sub_model_list: @@ -268,11 +268,19 @@ class SharkifyStableDiffusionModel: ) if self.base_vae: sub_model = "base_vae" - # TODO: Fix this - # if "stencil_adaptor" == model and self.use_stencil is not None: - # model_config = model_config + get_path_stem(self.use_stencil) - model_name[model] = get_extended_name(sub_model + model_config) - index += 1 + if "stencil_adapter" in model: + stencil_names = [] + for i, stencil in enumerate(self.stencils): + if stencil is not None: + cnet_config = model_config + stencil.split("_")[-1] + stencil_names.append( + get_extended_name(sub_model + cnet_config) + ) + + model_name[model] = stencil_names + else: + model_name[model] = get_extended_name(sub_model + model_config) + 
index += 1 return model_name def check_params(self, max_len, width, height): @@ -679,6 +687,8 @@ class SharkifyStableDiffusionModel: def get_control_net(self, stencil_id, use_large=False): stencil_id = get_stencil_model_id(stencil_id) + adapter_id, base_model_safe_id, ext_model_name = (None, None, None) + print(f"Importing ControlNet adapter from {stencil_id}") class StencilControlNetModel(torch.nn.Module): def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False): @@ -687,7 +697,7 @@ class SharkifyStableDiffusionModel: model_id, low_cpu_mem_usage=low_cpu_mem_usage, ) - self.in_channels = self.cnet.in_channels + self.in_channels = self.cnet.config.in_channels self.train(False) def forward( @@ -751,7 +761,23 @@ class SharkifyStableDiffusionModel: ) is_f16 = True if self.precision == "fp16" else False - inputs = tuple(self.inputs["stencil_adaptor"]) + inputs = tuple(self.inputs["stencil_adapter"]) + model_name = "stencil_adapter_512" if use_large else "stencil_adapter" + ext_model_name = self.model_name[model_name] + if isinstance(ext_model_name, list): + for i in ext_model_name: + if stencil_id.split("_")[-1] in i: + desired_name = i + print(f"Multi-CN: compiling model {i}") + else: + continue + if desired_name: + ext_model_name = desired_name + else: + raise Exception( + f"Could not find extended configuration for {stencil_id}" + ) + if use_large: pad = (0, 0) * (len(inputs[2].shape) - 2) pad = pad + (0, 512 - inputs[2].shape[1]) @@ -761,19 +787,13 @@ class SharkifyStableDiffusionModel: torch.nn.functional.pad(inputs[2], pad), *inputs[3:], ) - save_dir = os.path.join( - self.sharktank_dir, self.model_name["stencil_adaptor_512"] - ) - else: - save_dir = os.path.join( - self.sharktank_dir, self.model_name["stencil_adaptor"] - ) + save_dir = os.path.join(self.sharktank_dir, ext_model_name) input_mask = [True, True, True, True] + ([True] * 13) - model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512" + shark_cnet, cnet_mlir = compile_through_fx( 
scnet, inputs, - extended_model_name=self.model_name[model_name], + extended_model_name=ext_model_name, is_f16=is_f16, f16_input_mask=input_mask, use_tuned=self.use_tuned, @@ -1315,16 +1335,16 @@ class SharkifyStableDiffusionModel: def controlnet(self, stencil_id, use_large=False): try: - self.inputs["stencil_adaptor"] = self.get_input_info_for( - base_models["stencil_adaptor"] + self.inputs["stencil_adapter"] = self.get_input_info_for( + base_models["stencil_adapter"] ) - compiled_stencil_adaptor, controlnet_mlir = self.get_control_net( + compiled_stencil_adapter, controlnet_mlir = self.get_control_net( stencil_id, use_large=use_large ) - check_compilation(compiled_stencil_adaptor, "Stencil") + check_compilation(compiled_stencil_adapter, "Stencil") if self.return_mlir: return controlnet_mlir - return compiled_stencil_adaptor + return compiled_stencil_adapter except Exception as e: sys.exit(e) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py index 4d35ea2d..386feeae 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -391,7 +391,6 @@ class StencilPipeline(StableDiffusionPipeline): control_mode, ): # Control Embedding check & conversion - # TODO: 1. Change `num_images_per_prompt`. 
# controlnet_hint = controlnet_hint_conversion( # image, use_stencil, height, width, dtype, num_images_per_prompt=1 # ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 7d3d6b1c..6cfda0cb 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -825,7 +825,7 @@ class StableDiffusionPipeline: gc.collect() self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}" - return text_embeddings.numpy() + return text_embeddings.numpy().astype(np.float16) from typing import List, Optional, Union diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py index c074af4d..3c25dc40 100644 --- a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py +++ b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py @@ -207,16 +207,21 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): sigma_hat = sigma * (gamma + 1) - noise = randn_tensor( - noise_pred.shape, - dtype=noise_pred.dtype, - device="cpu", - generator=generator, + noise_pred = ( + torch.from_numpy(noise_pred) + if isinstance(noise_pred, np.ndarray) + else noise_pred ) - eps = noise * s_noise - if gamma > 0: + noise = randn_tensor( + torch.Size(noise_pred.shape), + dtype=torch.float16, + device="cpu", + generator=generator, + ) + + eps = noise * s_noise latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5 if self.config.prediction_type == "v_prediction": diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json index 5cbf9655..6adce2ad 100644 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ b/apps/stable_diffusion/src/utils/resources/base_model.json @@ -276,7 +276,7 @@ } } }, - 
"stencil_adaptor": { + "stencil_adapter": { "latents": { "shape": [ "1*batch_size", diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py index 41800f0c..ba6d5b06 100644 --- a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py +++ b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py @@ -79,10 +79,12 @@ def controlnet_hint_shaping( ) return controlnet_hint else: - raise ValueError( - f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width})," - + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, " - + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}" + return controlnet_hint_shaping( + Image.fromarray(controlnet_hint.detach().numpy()), + height, + width, + dtype, + num_images_per_prompt, ) elif isinstance(controlnet_hint, np.ndarray): # np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot) @@ -109,29 +111,36 @@ def controlnet_hint_shaping( ) # b h w c -> b c h w return controlnet_hint else: - raise ValueError( - f"Acceptble shape of `stencil` are any of ({width}, {channels}), " - + f"({height}, {width}, {channels}), " - + f"(1, {height}, {width}, {channels}) or " - + f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}" - ) - elif isinstance(controlnet_hint, Image.Image): - if controlnet_hint.size == (width, height): - controlnet_hint = controlnet_hint.convert( - "RGB" - ) # make sure 3 channel RGB format - controlnet_hint = np.array(controlnet_hint) # to numpy - controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR return controlnet_hint_shaping( - controlnet_hint, height, width, num_images_per_prompt + Image.fromarray(controlnet_hint), + height, + width, + dtype, + num_images_per_prompt, ) + + elif isinstance(controlnet_hint, Image.Image): + controlnet_hint = controlnet_hint.convert( + "RGB" + ) # make sure 3 channel RGB format + if 
controlnet_hint.size == (width, height): + controlnet_hint = np.array(controlnet_hint).astype( + np.float16 + ) # to numpy + controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR + return else: - raise ValueError( - f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}" - ) + (hint_w, hint_h) = controlnet_hint.size + left = int((hint_w - width) / 2) + right = left + height + controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h)) + controlnet_hint = controlnet_hint.resize((width, height)) + return controlnet_hint_shaping( + controlnet_hint, height, width, dtype, num_images_per_prompt + ) else: raise ValueError( - f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}" + f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}" ) @@ -141,16 +150,22 @@ def controlnet_hint_conversion( controlnet_hint = None match use_stencil: case "canny": - print("Detecting edge with canny") + print( + "Converting controlnet hint to edge detection mask with canny preprocessor." + ) controlnet_hint = hint_canny(image) case "openpose": - print("Detecting human pose") + print( + "Detecting human pose in controlnet hint with openpose preprocessor." + ) controlnet_hint = hint_openpose(image) case "scribble": - print("Working with scribble") + print("Using your scribble as a controlnet hint.") controlnet_hint = hint_scribble(image) case "zoedepth": - print("Working with ZoeDepth") + print( + "Converting controlnet hint to a depth mapping with ZoeDepth." 
+ ) controlnet_hint = hint_zoedepth(image) case _: return None diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 1d8a7f53..d3bfe544 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -6,6 +6,12 @@ import PIL from math import ceil from PIL import Image +from gradio.components.image_editor import ( + Brush, + Eraser, + EditorData, + EditorValue, +) from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -189,7 +195,7 @@ def img2img_inf( model_id = ( args.hf_model_id if args.hf_model_id - else "stabilityai/stable-diffusion-2-1-base" + else "stabilityai/stable-diffusion-1-5-base" ) global_obj.set_schedulers(get_schedulers(model_id)) scheduler_obj = global_obj.get_scheduler(args.scheduler) @@ -403,6 +409,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: result = openpose( np.array(input_image["composite"]) ) + print(result) # TODO: This is just an empty canvas, need to draw the candidates (which are in result[1]) return ( Image.fromarray(result[0]), @@ -415,42 +422,53 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: np.array(input_image["composite"]) ) return ( - Image.fromarray(result[0]), + Image.fromarray(result), stencils, images, ) case "scribble": - result = input_image["composite"].convert("L") - return (result, stencils, images) + return ( + input_image["composite"], + stencils, + images, + ) case _: return (None, stencils, images) def create_canvas(width, height): - data = ( + data = Image.fromarray( np.zeros( shape=(height, width, 3), dtype=np.uint8, ) + 255 ) - return data + img_dict = { + "background": data, + "layers": [data], + "composite": None, + } + return EditorValue(img_dict) def update_cn_input(model, width, height): if model == "scribble": return [ gr.ImageEditor( visible=True, - image_mode="RGB", interactive=True, show_label=False, + image_mode="RGB", type="pil", value=create_canvas(width, 
height), - crop_size=(width, height), + brush=Brush( + colors=["#000000"], color_mode="fixed" + ), ), gr.Image( visible=True, show_label=False, interactive=False, + show_download_button=False, ), gr.Slider(visible=True), gr.Slider(visible=True), @@ -469,6 +487,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: visible=True, show_label=False, interactive=True, + show_download_button=False, ), gr.Slider(visible=False), gr.Slider(visible=False), @@ -505,6 +524,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: value="Make Canvas!", visible=False, ) + cnet_1_image = gr.ImageEditor( visible=False, image_mode="RGB", @@ -515,6 +535,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: cnet_1_output = gr.Image( visible=True, show_label=False ) + cnet_1_model.input( update_cn_input, [cnet_1_model, canvas_width, canvas_height],