From c42ef3cd606df3accad335cd7b009e34a07a93dc Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Wed, 22 Feb 2023 15:00:26 +0000
Subject: [PATCH] [WIP] Add ControlNet to SD pipeline

This commit adds ControlNet to the SD pipeline.

Signed-off-by: Abhishek Varma
---
 apps/README.md                                |  11 ++
 .../src/models/model_wrappers.py              |   8 +-
 ...pipeline_shark_stable_diffusion_txt2img.py |  76 ++++++++++
 .../pipeline_shark_stable_diffusion_utils.py  | 141 ++++++++++++++++--
 .../src/utils/resources/base_model.json       |  52 +++++++
 5 files changed, 276 insertions(+), 12 deletions(-)
 create mode 100644 apps/README.md

diff --git a/apps/README.md b/apps/README.md
new file mode 100644
index 00000000..ca8ac0e8
--- /dev/null
+++ b/apps/README.md
@@ -0,0 +1,11 @@
+# [WIP] ControlNet integration
+* ControlNet support in `diffusers` is still a [WIP draft PR](https://github.com/huggingface/diffusers/pull/2407), so clone [takuma104/diffusers](https://github.com/takuma104/diffusers/), check out the `controlnet` [branch](https://github.com/takuma104/diffusers/tree/controlnet), and build `diffusers` from source as described [here](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
+* NOTE: Run `pip uninstall diffusers` before building it from source in the previous step.
+* ControlNet (the Canny variant) is currently included as part of our SharkStableDiffusion pipeline.
+* To test it, first download [control_sd15_canny.pth](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth) and convert it with [convert_controlnet_to_diffusers.py](https://github.com/huggingface/diffusers/blob/faf1cfbe826c88366524e92fa27b2104effdb8c4/scripts/convert_controlnet_to_diffusers.py).
+* Then modify [this](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py#L62) line to point to the extracted `diffusers` checkpoint directory.
+* Download the [sample image](https://drive.google.com/file/d/13ncuKjTO-reK-nBcLM1Lzuli5WPPB38v/view?usp=sharing) and provide its path [here](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py#L152).
+* Finally, run:
+```shell
+python apps/stable_diffusion/scripts/txt2img.py --precision=fp32 --device=cuda --prompt="bird" --max_length=64 --import_mlir --enable_stack_trace --no-use_tuned --hf_model_id="runwayml/stable-diffusion-v1-5" --scheduler="DDIM"
+```
\ No newline at end of file
diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py
index d9ef2bec..b5727291 100644
--- a/apps/stable_diffusion/src/models/model_wrappers.py
+++ b/apps/stable_diffusion/src/models/model_wrappers.py
@@ -227,13 +227,17 @@ class SharkifyStableDiffusionModel:
         self.in_channels = self.unet.in_channels
         self.train(False)
 
+    # TODO: Instead of flattening `control`, try to use the list directly.
     def forward(
-        self, latent, timestep, text_embedding, guidance_scale
+        self, latent, timestep, text_embedding, guidance_scale,
+        control1, control2, control3, control4, control5, control6, control7,
+        control8, control9, control10, control11, control12, control13
     ):
+        control = [control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12, control13]
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         latents = torch.cat([latent] * 2)
         unet_out = self.unet.forward(
-            latents, timestep, text_embedding, return_dict=False
+            latents, timestep, text_embedding, control=control, return_dict=False
         )[0]
         noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
index 98219631..55dba684 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
@@ -19,6 +19,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
     StableDiffusionPipeline,
 )
 
+import cv2
+from PIL import Image
+
 
 class Text2ImagePipeline(StableDiffusionPipeline):
     def __init__(
@@ -65,6 +68,66 @@ class Text2ImagePipeline(StableDiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    # TODO: Move this to a separate set of ControlNet-specific model functions.
+    def HWC3(self, x):
+        assert x.dtype == np.uint8
+        if x.ndim == 2:
+            x = x[:, :, None]
+        assert x.ndim == 3
+        H, W, C = x.shape
+        assert C == 1 or C == 3 or C == 4
+        if C == 3:
+            return x
+        if C == 1:
+            return np.concatenate([x, x, x], axis=2)
+        if C == 4:
+            color = x[:, :, 0:3].astype(np.float32)
+            alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+            y = color * alpha + 255.0 * (1.0 - alpha)
+            y = y.clip(0, 255).astype(np.uint8)
+            return y
+
+    # TODO: Move this to a separate set of ControlNet-specific model functions.
+    def resize_image(self, input_image, resolution):
+        H, W, C = input_image.shape
+        H = float(H)
+        W = float(W)
+        k = float(resolution) / min(H, W)
+        H *= k
+        W *= k
+        H = int(np.round(H / 64.0)) * 64
+        W = int(np.round(W / 64.0)) * 64
+        img = cv2.resize(
+            input_image,
+            (W, H),
+            interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
+        )
+        return img
+
+    # TODO: Move this to a separate set of ControlNet-specific model functions.
+    def hint_canny(
+        self,
+        image: Image.Image,
+        width=512,
+        height=512,
+        low_threshold=100,
+        high_threshold=200,
+    ):
+        with torch.no_grad():
+            input_image = np.array(image)
+            image_resolution = width
+
+            img = self.resize_image(self.HWC3(input_image), image_resolution)
+
+            class CannyDetector:
+                def __call__(self, img, low_threshold, high_threshold):
+                    return cv2.Canny(img, low_threshold, high_threshold)
+
+            canny = CannyDetector()
+            detected_map = canny(img, low_threshold, high_threshold)
+            detected_map = self.HWC3(detected_map)
+            return detected_map
+
     def generate_images(
         self,
         prompts,
@@ -80,6 +143,18 @@ class Text2ImagePipeline(StableDiffusionPipeline):
         use_base_vae,
         cpu_scheduling,
     ):
+        # Control embedding check & conversion.
+        # TODO: 1. Plumb `controlnet_hint` through as an argument.
+        #       2. Change `num_images_per_prompt`.
+        #       3. Supply `controlnet_img` instead of the hard-coded path below.
+        from PIL import Image
+
+        controlnet_img = Image.open("/home/abhishek/cn_1.png")
+        controlnet_hint = self.hint_canny(controlnet_img)
+        controlnet_hint = self.controlnet_hint_conversion(
+            controlnet_hint, height, width, num_images_per_prompt=1
+        )
+
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
             prompts = [prompts]
@@ -122,6 +197,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
             total_timesteps=self.scheduler.timesteps,
             dtype=dtype,
             cpu_scheduling=cpu_scheduling,
+            controlnet_hint=controlnet_hint,
         )
 
         # Img latents -> PIL images
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
index 2b6e9a2d..00c44c2e 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
@@ -55,6 +55,12 @@ class StableDiffusionPipeline:
         self.scheduler = scheduler
         # TODO: Implement using logging python utility.
         self.log = ""
+        # TODO: Make this dynamic like other models which'll be passed to StableDiffusionPipeline.
+        from diffusers import UNet2DConditionModel
+
+        self.controlnet = UNet2DConditionModel.from_pretrained(
+            "/home/abhishek/weights/canny_weight", subfolder="controlnet"
+        )
 
     def encode_prompts(self, prompts, neg_prompts, max_length):
         # Tokenize text and get embeddings
@@ -116,6 +122,7 @@ class StableDiffusionPipeline:
         total_timesteps,
         dtype,
         cpu_scheduling,
+        controlnet_hint=None,
         mask=None,
         masked_image_latents=None,
         return_all_latents=False,
@@ -126,7 +133,7 @@ class StableDiffusionPipeline:
         text_embeddings_numpy = text_embeddings.detach().numpy()
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
-            timestep = torch.tensor([t]).to(dtype).detach().numpy()
+            timestep = torch.tensor([t]).to(dtype)
             latent_model_input = self.scheduler.scale_model_input(latents, t)
             if mask is not None and masked_image_latents is not None:
                 latent_model_input = torch.cat(
@@ -142,16 +149,56 @@
 
             # Profiling Unet.
             profile_device = start_profiling(file_path="unet.rdc")
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
+            if controlnet_hint is not None:
+                if not torch.is_tensor(latent_model_input):
+                    latent_model_input_1 = torch.from_numpy(
+                        np.asarray(latent_model_input)
+                    ).to(dtype)
+                else:
+                    latent_model_input_1 = latent_model_input
+                control = self.controlnet(
+                    latent_model_input_1,
                     timestep,
-                    text_embeddings_numpy,
-                    guidance_scale,
-                ),
-                send_to_host=False,
-            )
+                    encoder_hidden_states=text_embeddings,
+                    controlnet_hint=controlnet_hint,
+                )
+                timestep = timestep.detach().numpy()
+                # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                        control[0],
+                        control[1],
+                        control[2],
+                        control[3],
+                        control[4],
+                        control[5],
+                        control[6],
+                        control[7],
+                        control[8],
+                        control[9],
+                        control[10],
+                        control[11],
+                        control[12],
+                    ),
+                    send_to_host=False,
+                )
+            else:
+                timestep = timestep.detach().numpy()
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                    ),
+                    send_to_host=False,
+                )
             end_profiling(profile_device)
 
             if cpu_scheduling:
@@ -177,6 +224,80 @@ class StableDiffusionPipeline:
         all_latents = torch.cat(latent_history, dim=0)
         return all_latents
 
+    def controlnet_hint_conversion(
+        self, controlnet_hint, height, width, num_images_per_prompt=1
+    ):
+        channels = 3
+        if isinstance(controlnet_hint, torch.Tensor):
+            # torch.Tensor: acceptable shapes are chw, bchw (b == 1) or bchw (b == num_images_per_prompt).
+            shape_chw = (channels, height, width)
+            shape_bchw = (1, channels, height, width)
+            shape_nchw = (num_images_per_prompt, channels, height, width)
+            if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
+                controlnet_hint = controlnet_hint.to(
+                    dtype=torch.float32, device=torch.device("cpu")
+                )
+                if controlnet_hint.shape != shape_nchw:
+                    controlnet_hint = controlnet_hint.repeat(
+                        num_images_per_prompt, 1, 1, 1
+                    )
+                return controlnet_hint
+            else:
+                raise ValueError(
+                    f"Acceptable shape of `controlnet_hint` is any of ({channels}, {height}, {width}),"
+                    + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+                    + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
+                )
+        elif isinstance(controlnet_hint, np.ndarray):
+            # np.ndarray: acceptable shapes are hw, hwc, bhwc (b == 1) or bhwc (b == num_images_per_prompt).
+            # hwc is the OpenCV-compatible image format; color channels must be in BGR order.
+            if controlnet_hint.shape == (height, width):
+                controlnet_hint = np.repeat(
+                    controlnet_hint[:, :, np.newaxis], channels, axis=2
+                )  # hw -> hwc (c == 3)
+            shape_hwc = (height, width, channels)
+            shape_bhwc = (1, height, width, channels)
+            shape_nhwc = (num_images_per_prompt, height, width, channels)
+            if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
+                controlnet_hint = torch.from_numpy(controlnet_hint.copy())
+                controlnet_hint = controlnet_hint.to(
+                    dtype=torch.float32, device=torch.device("cpu")
+                )
+                controlnet_hint /= 255.0
+                if controlnet_hint.shape != shape_nhwc:
+                    controlnet_hint = controlnet_hint.repeat(
+                        num_images_per_prompt, 1, 1, 1
+                    )
+                controlnet_hint = controlnet_hint.permute(
+                    0, 3, 1, 2
+                )  # b h w c -> b c h w
+                return controlnet_hint
+            else:
+                raise ValueError(
+                    f"Acceptable shape of `controlnet_hint` is any of ({height}, {width}), "
+                    + f"({height}, {width}, {channels}), "
+                    + f"(1, {height}, {width}, {channels}) or "
+                    + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
+                )
+        elif isinstance(controlnet_hint, Image.Image):
+            if controlnet_hint.size == (width, height):
+                controlnet_hint = controlnet_hint.convert(
+                    "RGB"
+                )  # make sure it is in 3-channel RGB format
+                controlnet_hint = np.array(controlnet_hint)  # to numpy
+                controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
+                return self.controlnet_hint_conversion(
+                    controlnet_hint, height, width, num_images_per_prompt
+                )
+            else:
+                raise ValueError(
+                    f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
+                )
+        else:
+            raise ValueError(
+                f"Acceptable type of `controlnet_hint` is any of torch.Tensor, np.ndarray or PIL.Image.Image but is {type(controlnet_hint)}"
+            )
+
     @classmethod
     def from_pretrained(
         cls,
diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json
index 041ad6ee..c2480c62 100644
--- a/apps/stable_diffusion/src/utils/resources/base_model.json
+++ b/apps/stable_diffusion/src/utils/resources/base_model.json
@@ -83,6 +83,58 @@
         "guidance_scale": {
             "shape": 2,
             "dtype": "f32"
+        },
+        "control1": {
+            "shape": [1, 320, 64, 64],
+            "dtype": "f32"
+        },
+        "control2": {
+            "shape": [1, 320, 64, 64],
+            "dtype": "f32"
+        },
+        "control3": {
+            "shape": [1, 320, 64, 64],
+            "dtype": "f32"
+        },
+        "control4": {
+            "shape": [1, 320, 32, 32],
+            "dtype": "f32"
+        },
+        "control5": {
+            "shape": [1, 640, 32, 32],
+            "dtype": "f32"
+        },
+        "control6": {
+            "shape": [1, 640, 32, 32],
+            "dtype": "f32"
+        },
+        "control7": {
+            "shape": [1, 640, 16, 16],
+            "dtype": "f32"
+        },
+        "control8": {
+            "shape": [1, 1280, 16, 16],
+            "dtype": "f32"
+        },
+        "control9": {
+            "shape": [1, 1280, 16, 16],
+            "dtype": "f32"
+        },
+        "control10": {
+            "shape": [1, 1280, 8, 8],
+            "dtype": "f32"
+        },
+        "control11": {
+            "shape": [1, 1280, 8, 8],
+            "dtype": "f32"
+        },
+        "control12": {
+            "shape": [1, 1280, 8, 8],
+            "dtype": "f32"
+        },
+        "control13": {
+            "shape": [1, 1280, 8, 8],
+            "dtype": "f32"
         }
     },
     "vae_encode": {
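For reviewers, a minimal sketch of how the thirteen residual tensors produced by the ControlNet map onto the flattened `control1` … `control13` inputs registered in `base_model.json` above. The shape table is copied from that JSON; the helper name `flatten_control_residuals` and its validation logic are illustrative assumptions and are not part of this patch:

```python
import torch

# Shapes registered for the flattened ControlNet residuals in base_model.json
# (index 0 corresponds to `control1`, index 12 to `control13`), batch 1, fp32.
CONTROL_SHAPES = [
    (1, 320, 64, 64), (1, 320, 64, 64), (1, 320, 64, 64),
    (1, 320, 32, 32), (1, 640, 32, 32), (1, 640, 32, 32),
    (1, 640, 16, 16), (1, 1280, 16, 16), (1, 1280, 16, 16),
    (1, 1280, 8, 8), (1, 1280, 8, 8), (1, 1280, 8, 8), (1, 1280, 8, 8),
]


def flatten_control_residuals(control):
    """Validate the 13 ControlNet residuals and return them as a flat tuple,
    matching the positional `control1` ... `control13` UNet inputs."""
    if len(control) != len(CONTROL_SHAPES):
        raise ValueError(
            f"expected {len(CONTROL_SHAPES)} control tensors, got {len(control)}"
        )
    for i, (tensor, shape) in enumerate(zip(control, CONTROL_SHAPES), start=1):
        if tuple(tensor.shape) != shape:
            raise ValueError(
                f"control{i} has shape {tuple(tensor.shape)}, expected {shape}"
            )
    return tuple(t.to(torch.float32) for t in control)
```

Passing the residuals positionally like this mirrors the TODOs in `model_wrappers.py` and `pipeline_shark_stable_diffusion_utils.py`: the flattening is a temporary workaround until the compiled UNet can accept the list directly.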