From be3cdec290ed2c7152ee915599587fc7a6f1ea7a Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 2 Mar 2023 00:14:40 +0530 Subject: [PATCH] [SD] Add Stencil feature to SD pipeline (#1111) * [WIP] Add ControlNet to SD pipeline -- This commit adds ControlNet to SD pipeline. Signed-off-by: Abhishek Varma * [SD] Add ControlNet to img2img + fix bug for img2img scheduler -- This commit adds ControlNet execution to img2img. -- It restructures the addition of ControlNet variants. -- It also fixes scheduler selecting bug for img2img pipeline. Signed-off-by: Abhishek Varma * add shark models for stencilSD * Add Stencil controlled SD in img2img pipeline (#1106) * use shark stencil modules * adjust diffusers change * modify to use pipeline * remove control from unet * pump stencils through unet * complete integration in img2img * fix lint and comments * [SD] Add ControlNet pipeline + integrate with WebUI + add compiled flow execution -- This commit creates a dedicated SD pipeline for ControlNet. -- Integrates it with img2img WebUI. -- Integrates the compiled execution flow for ControlNet. Signed-off-by: Abhishek Varma * [SD] Stencil execution * Remove integration setup * [SD] Fix args.use_stencil overriding bug + vmfb caching issue -- This commit fixes args.use_stencil overriding issue which caused img2img pipeline to pick wrong set of modules. -- It also fixes vmfb caching issue to speed up the loading time and pick right set of modules based on a mask. Signed-off-by: Abhishek Varma --------- Signed-off-by: Abhishek Varma Co-authored-by: Abhishek Varma Co-authored-by: PhaneeshB --- apps/stable_diffusion/scripts/img2img.py | 137 ++++++++++---- apps/stable_diffusion/src/__init__.py | 1 + .../src/models/model_wrappers.py | 171 ++++++++++++++++-- .../src/pipelines/__init__.py | 3 + ...pipeline_shark_stable_diffusion_img2img.py | 1 + ...pipeline_shark_stable_diffusion_stencil.py | 150 +++++++++++++++ ...pipeline_shark_stable_diffusion_txt2img.py | 3 + .../pipeline_shark_stable_diffusion_utils.py | 129 +++++++++++++ apps/stable_diffusion/src/utils/__init__.py | 3 + .../src/utils/resources/base_model.json | 112 +++++++++++- .../stable_diffusion/src/utils/stable_args.py | 6 + .../src/utils/stencils/canny/__init__.py | 6 + .../src/utils/stencils/stencil_utils.py | 155 ++++++++++++++++ apps/stable_diffusion/src/utils/utils.py | 19 +- apps/stable_diffusion/web/ui/img2img_ui.py | 10 +- 15 files changed, 840 insertions(+), 66 deletions(-) create mode 100644 apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py create mode 100644 apps/stable_diffusion/src/utils/stencils/canny/__init__.py create mode 100644 apps/stable_diffusion/src/utils/stencils/stencil_utils.py diff --git a/apps/stable_diffusion/scripts/img2img.py b/apps/stable_diffusion/scripts/img2img.py index 4ab491c1..627d64bc 100644 --- a/apps/stable_diffusion/scripts/img2img.py +++ b/apps/stable_diffusion/scripts/img2img.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from apps.stable_diffusion.src import ( args, Image2ImagePipeline, + StencilPipeline, get_schedulers, set_init_device_flags, utils, @@ -24,6 +25,7 @@ class Config: height: int width: int device: str + use_stencil: str img2img_obj = None @@ -50,6 +52,7 @@ def img2img_inf( precision: str, device: str, max_length: int, + use_stencil: str, save_metadata_to_json: bool, save_metadata_to_png: bool, ): @@ -92,8 +95,24 @@ def img2img_inf( args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png + use_stencil = None if 
use_stencil == "None" else use_stencil + args.use_stencil = use_stencil + if use_stencil is not None: + args.scheduler = "DDIM" + args.hf_model_id = "runwayml/stable-diffusion-v1-5" + elif args.scheduler != "PNDM": + if "Shark" in args.scheduler: + print( + f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler" + ) + args.scheduler = "PNDM" + else: + sys.exit( + "Img2Img works best with PNDM scheduler. Other schedulers are not supported yet." + ) + cpu_scheduling = not args.scheduler.startswith("Shark") + args.precision = precision dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") new_config_obj = Config( args.hf_model_id, args.ckpt_loc, @@ -103,10 +122,10 @@ def img2img_inf( height, width, device, + use_stencil, ) if not img2img_obj or config_obj != new_config_obj: config_obj = new_config_obj - args.precision = precision args.batch_size = batch_size args.max_length = max_length args.height = height @@ -123,21 +142,40 @@ def img2img_inf( ) schedulers = get_schedulers(model_id) scheduler_obj = schedulers[scheduler] - img2img_obj = Image2ImagePipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - ) + if use_stencil is not None: + args.use_tuned = False + img2img_obj = StencilPipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.custom_vae, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + args.use_tuned, + low_cpu_mem_usage=args.low_cpu_mem_usage, + use_stencil=use_stencil, + ) + else: + img2img_obj = Image2ImagePipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.custom_vae, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + args.use_tuned, + low_cpu_mem_usage=args.low_cpu_mem_usage, + ) img2img_obj.scheduler = schedulers[scheduler] @@ -165,6 +203,7 @@ def img2img_inf( dtype, args.use_base_vae, cpu_scheduling, + use_stencil=use_stencil, ) save_output_img(out_imgs[0], img_seed, extra_info) generated_imgs.extend(out_imgs) @@ -195,11 +234,11 @@ if __name__ == "__main__": # When the models get uploaded, it should be default to False. args.import_mlir = True - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - if args.scheduler != "PNDM": + use_stencil = args.use_stencil + if use_stencil: + args.scheduler = "DDIM" + args.hf_model_id = "runwayml/stable-diffusion-v1-5" + elif args.scheduler != "PNDM": if "Shark" in args.scheduler: print( f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler" @@ -209,28 +248,49 @@ if __name__ == "__main__": sys.exit( "Img2Img works best with PNDM scheduler. Other schedulers are not supported yet." 
) + cpu_scheduling = not args.scheduler.startswith("Shark") + dtype = torch.float32 if args.precision == "fp32" else torch.half + set_init_device_flags() + schedulers = get_schedulers(args.hf_model_id) scheduler_obj = schedulers[args.scheduler] image = Image.open(args.img_path).convert("RGB") seed = utils.sanitize_seed(args.seed) - # Adjust for height and width based on model - img2img_obj = Image2ImagePipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - ) + if use_stencil: + img2img_obj = StencilPipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.custom_vae, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + args.use_tuned, + low_cpu_mem_usage=args.low_cpu_mem_usage, + use_stencil=use_stencil, + ) + else: + img2img_obj = Image2ImagePipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.custom_vae, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + args.use_tuned, + low_cpu_mem_usage=args.low_cpu_mem_usage, + ) start_time = time.time() generated_imgs = img2img_obj.generate_images( @@ -248,6 +308,7 @@ if __name__ == "__main__": dtype, args.use_base_vae, cpu_scheduling, + use_stencil=use_stencil, ) total_time = time.time() - start_time text_output = f"prompt={args.prompts}" diff --git a/apps/stable_diffusion/src/__init__.py b/apps/stable_diffusion/src/__init__.py index 16276fda..37c55b52 100644 --- a/apps/stable_diffusion/src/__init__.py +++ b/apps/stable_diffusion/src/__init__.py @@ -11,5 +11,6 @@ from apps.stable_diffusion.src.pipelines import ( Image2ImagePipeline, InpaintPipeline, OutpaintPipeline, + StencilPipeline, ) from apps.stable_diffusion.src.schedulers import get_schedulers diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 05872e51..5e56e214 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -1,4 +1,4 @@ -from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel from transformers import CLIPTextModel from collections import defaultdict import torch @@ -117,10 +117,14 @@ class SharkifyStableDiffusionModel: self.model_name = self.model_name + "_" + get_path_stem(self.model_id) self.low_cpu_mem_usage = low_cpu_mem_usage - def get_extended_name_for_all_model(self): + def get_extended_name_for_all_model(self, mask_to_fetch): model_name = {} - sub_model_list = ["clip", "unet", "vae", "vae_encode"] + sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"] + index = 0 for model in sub_model_list: + if mask_to_fetch[index] == False: + index += 1 + continue sub_model = model model_config = self.model_name if "vae" == model: @@ -129,6 +133,7 @@ class SharkifyStableDiffusionModel: if self.base_vae: sub_model = "base_vae" model_name[model] = get_extended_name(sub_model + model_config) + index += 1 return model_name def check_params(self, max_len, width, height): @@ -215,6 +220,112 @@ class SharkifyStableDiffusionModel: ) return shark_vae + def get_controlled_unet(self): + class 
ControlledUnetModel(torch.nn.Module):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
+                super().__init__()
+                self.unet = UNet2DConditionModel.from_pretrained(
+                    "takuma104/control_sd15_canny",  # TODO: ADD with model ID
+                    subfolder="unet",
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                )
+                self.in_channels = self.unet.in_channels
+                self.train(False)
+
+            def forward(
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                guidance_scale,
+                control1,
+                control2,
+                control3,
+                control4,
+                control5,
+                control6,
+                control7,
+                control8,
+                control9,
+                control10,
+                control11,
+                control12,
+                control13,
+            ):
+                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+                db_res_samples = (
+                    control1,
+                    control2,
+                    control3,
+                    control4,
+                    control5,
+                    control6,
+                    control7,
+                    control8,
+                    control9,
+                    control10,
+                    control11,
+                    control12,
+                )
+                mb_res_samples = control13
+                latents = torch.cat([latent] * 2)
+                unet_out = self.unet.forward(
+                    latents,
+                    timestep,
+                    encoder_hidden_states=text_embedding,
+                    down_block_additional_residuals=db_res_samples,
+                    mid_block_additional_residual=mb_res_samples,
+                    return_dict=False,
+                )[0]
+                noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
+                return noise_pred
+
+        unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
+        is_f16 = True if self.precision == "fp16" else False
+
+        inputs = tuple(self.inputs["stencil_unet"])
+        # Cast every input except `guidance_scale` (index 3) to f16.
+        input_mask = [
+            True, True, True, False, True, True, True, True, True,
+            True, True, True, True, True, True, True, True,
+        ]
+        shark_controlled_unet = compile_through_fx(
+            unet,
+            inputs,
+            model_name=self.model_name["stencil_unet"],
+            is_f16=is_f16,
+            f16_input_mask=input_mask,
+            use_tuned=self.use_tuned,
+            extra_args=get_opt_flags("unet", precision=self.precision),
+        )
+        return shark_controlled_unet
+
+    def get_control_net(self):
+        class StencilControlNetModel(torch.nn.Module):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
+                super().__init__()
+                self.cnet = ControlNetModel.from_pretrained(
+                    "takuma104/control_sd15_canny",  # TODO: ADD with model ID
+                    subfolder="controlnet",
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                )
+                self.in_channels = self.cnet.in_channels
+                self.train(False)
+
+            def forward(
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                stencil_image_input,
+            ):
+                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
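+                # Both the latents and the stencil hint are duplicated along
+                # the batch axis below so that the ControlNet residuals line
+                # up with the CFG-doubled batch in ControlledUnetModel.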
+ # TODO: guidance NOT NEEDED change in `get_input_info` later + latents = torch.cat( + [latent] * 2 + ) # needs to be same as controlledUNET latents + stencil_image = torch.cat( + [stencil_image_input] * 2 + ) # needs to be same as controlledUNET latents + down_block_res_samples, mid_block_res_sample = self.cnet.forward( + latents, + timestep, + encoder_hidden_states=text_embedding, + controlnet_cond=stencil_image, + return_dict=False, + ) + return tuple(list(down_block_res_samples) + [mid_block_res_sample]) + + scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) + is_f16 = True if self.precision == "fp16" else False + + inputs = tuple(self.inputs["stencil_adaptor"]) + input_mask = [True, True, True, True] + shark_cnet = compile_through_fx( + scnet, + inputs, + model_name=self.model_name["stencil_adaptor"], + is_f16=is_f16, + f16_input_mask=input_mask, + use_tuned=self.use_tuned, + extra_args=get_opt_flags("unet", precision=self.precision), + ) + return shark_cnet + def get_unet(self): class UnetModel(torch.nn.Module): def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False): @@ -232,8 +343,9 @@ class SharkifyStableDiffusionModel: else: self.unet.set_attention_slice(args.attention_slicing) + # TODO: Instead of flattening the `control` try to use the list. def forward( - self, latent, timestep, text_embedding, guidance_scale + self, latent, timestep, text_embedding, guidance_scale, ): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latents = torch.cat([latent] * 2) @@ -306,7 +418,7 @@ class SharkifyStableDiffusionModel: # Compiles Clip, Unet and Vae with `base_model_id` as defining their input # configiration. - def compile_all(self, base_model_id, need_vae_encode): + def compile_all(self, base_model_id, need_vae_encode, need_stencil): self.inputs = get_input_info( base_models[base_model_id], self.max_len, @@ -314,23 +426,45 @@ class SharkifyStableDiffusionModel: self.height, self.batch_size, ) - compiled_unet = self.get_unet() + compiled_controlnet = None + compiled_controlled_unet = None + compiled_unet = None + if need_stencil: + compiled_controlnet = self.get_control_net() + compiled_controlled_unet = self.get_controlled_unet() + else: + compiled_unet = self.get_unet() if self.custom_vae != "": print("Plugging in custom Vae") compiled_vae = self.get_vae() compiled_clip = self.get_clip() + + if need_stencil: + return compiled_clip, compiled_controlled_unet, compiled_vae, compiled_controlnet if need_vae_encode: compiled_vae_encode = self.get_vae_encode() return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode - return compiled_clip, compiled_unet, compiled_vae + return compiled_clip, compiled_unet, compiled_vae, None def __call__(self): # Step 1: # -- Fetch all vmfbs for the model, if present, else delete the lot. 
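+        # The stencil path swaps "unet"/"vae_encode" for "stencil_unet"/
+        # "stencil_adaptor", so a boolean mask over the six sub-models
+        # (`mask_to_fetch` below) picks which vmfbs to fetch or compile.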
- need_vae_encode = args.img_path is not None - self.model_name = self.get_extended_name_for_all_model() - vmfbs = fetch_or_delete_vmfbs(self.model_name, need_vae_encode, self.precision) + need_vae_encode, need_stencil = False, False + if args.img_path is not None: + if args.use_stencil is not None: + need_stencil = True + else: + need_vae_encode = True + # `mask_to_fetch` prepares a mask to pick a combination out of :- + # ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"] + mask_to_fetch = [True, True, False, True, False, False] + if need_vae_encode: + mask_to_fetch = [True, True, False, True, True, False] + elif need_stencil: + mask_to_fetch = [True, False, True, True, False, True] + self.model_name = self.get_extended_name_for_all_model(mask_to_fetch) + vmfbs = fetch_or_delete_vmfbs(self.model_name, self.precision) if vmfbs[0]: # -- If all vmfbs are indeed present, we also try and fetch the base # model configuration for running SD with custom checkpoints. @@ -339,8 +473,6 @@ class SharkifyStableDiffusionModel: if args.hf_model_id == "": sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.") print("Loaded vmfbs from cache and successfully fetched base model configuration.") - if not need_vae_encode: - return vmfbs[:3] return vmfbs # Step 2: @@ -363,7 +495,7 @@ class SharkifyStableDiffusionModel: print("Compiling all the models with the fetched base model configuration.") if args.ckpt_loc != "": args.hf_model_id = base_model_fetched - return self.compile_all(base_model_fetched, need_vae_encode) + return self.compile_all(base_model_fetched, need_vae_encode, need_stencil) # Step 3: # -- This is the retry mechanism where the base model's configuration is not @@ -372,9 +504,11 @@ class SharkifyStableDiffusionModel: for model_id in base_models: try: if need_vae_encode: - compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode) + compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode, need_stencil) + elif need_stencil: + compiled_clip, compiled_unet, compiled_vae, compiled_controlnet = self.compile_all(model_id, need_vae_encode, need_stencil) else: - compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode) + compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil) except Exception as e: print("Retrying with a different base model configuration") continue @@ -394,6 +528,13 @@ class SharkifyStableDiffusionModel: compiled_vae, compiled_vae_encode, ) + if need_stencil: + return ( + compiled_clip, + compiled_unet, + compiled_vae, + compiled_controlnet, + ) return compiled_clip, compiled_unet, compiled_vae sys.exit( "Cannot compile the model. 
Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues" diff --git a/apps/stable_diffusion/src/pipelines/__init__.py b/apps/stable_diffusion/src/pipelines/__init__.py index b58b5aa0..73363225 100644 --- a/apps/stable_diffusion/src/pipelines/__init__.py +++ b/apps/stable_diffusion/src/pipelines/__init__.py @@ -10,3 +10,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import ( OutpaintPipeline, ) +from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import ( + StencilPipeline, +) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py index 2ef97f70..9bac1388 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py @@ -112,6 +112,7 @@ class Image2ImagePipeline(StableDiffusionPipeline): dtype, use_base_vae, cpu_scheduling, + use_stencil, ): # prompts and negative prompts must be a list. if isinstance(prompts, str): diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py new file mode 100644 index 00000000..f4917e1e --- /dev/null +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -0,0 +1,150 @@ +import torch +import time +import numpy as np +from tqdm.auto import tqdm +from random import randint +from PIL import Image +from transformers import CLIPTokenizer +from typing import Union +from shark.shark_inference import SharkInference +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) +from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( + StableDiffusionPipeline, +) +from apps.stable_diffusion.src.utils import controlnet_hint_conversion + + +class StencilPipeline(StableDiffusionPipeline): + def __init__( + self, + controlnet: SharkInference, + vae: SharkInference, + text_encoder: SharkInference, + tokenizer: CLIPTokenizer, + unet: SharkInference, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + SharkEulerDiscreteScheduler, + ], + ): + super().__init__(vae, text_encoder, tokenizer, unet, scheduler) + self.controlnet = controlnet + + def prepare_latents( + self, + batch_size, + height, + width, + generator, + num_inference_steps, + dtype, + ): + latents = torch.randn( + ( + batch_size, + 4, + height // 8, + width // 8, + ), + generator=generator, + dtype=torch.float32, + ).to(dtype) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + latents = latents * self.scheduler.init_noise_sigma + return latents + + def generate_images( + self, + prompts, + neg_prompts, + image, + batch_size, + height, + width, + num_inference_steps, + strength, + guidance_scale, + seed, + max_length, + dtype, + use_base_vae, + cpu_scheduling, + use_stencil, + ): + # Control Embedding check & conversion + # TODO: 1. Change `num_images_per_prompt`. 
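+        # `controlnet_hint_conversion` (stencil_utils.py) runs the chosen
+        # detector (currently only canny) and returns an NCHW float hint
+        # tensor, or None when no stencil is selected.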
+        controlnet_hint = controlnet_hint_conversion(
+            image, use_stencil, height, width, dtype, num_images_per_prompt=1
+        )
+        # prompts and negative prompts must be a list.
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        if isinstance(neg_prompts, str):
+            neg_prompts = [neg_prompts]
+
+        prompts = prompts * batch_size
+        neg_prompts = neg_prompts * batch_size
+
+        # seed generator to create the initial latent noise. Also handle out of range seeds.
+        uint32_info = np.iinfo(np.uint32)
+        uint32_min, uint32_max = uint32_info.min, uint32_info.max
+        if seed < uint32_min or seed >= uint32_max:
+            seed = randint(uint32_min, uint32_max)
+        generator = torch.manual_seed(seed)
+
+        # Get text embeddings from prompts
+        text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
+
+        # guidance scale as a float32 tensor.
+        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
+
+        # Prepare initial latent.
+        init_latents = self.prepare_latents(
+            batch_size=batch_size,
+            height=height,
+            width=width,
+            generator=generator,
+            num_inference_steps=num_inference_steps,
+            dtype=dtype,
+        )
+        final_timesteps = self.scheduler.timesteps
+
+        # Denoise the latents under ControlNet (stencil) guidance.
+        latents = self.produce_stencil_latents(
+            latents=init_latents,
+            text_embeddings=text_embeddings,
+            guidance_scale=guidance_scale,
+            total_timesteps=final_timesteps,
+            dtype=dtype,
+            cpu_scheduling=cpu_scheduling,
+            controlnet_hint=controlnet_hint,
+            controlnet=self.controlnet,
+        )
+
+        # Img latents -> PIL images
+        all_imgs = []
+        for i in tqdm(range(0, latents.shape[0], batch_size)):
+            imgs = self.decode_latents(
+                latents=latents[i : i + batch_size],
+                use_base_vae=use_base_vae,
+                cpu_scheduling=cpu_scheduling,
+            )
+            all_imgs.extend(imgs)
+
+        return all_imgs
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
index 2c470f7f..242ee179 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
@@ -20,6 +20,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
     StableDiffusionPipeline,
 )
 
+import cv2
+from PIL import Image
+
 
 class Text2ImagePipeline(StableDiffusionPipeline):
     def __init__(
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
index 6fb1ae3d..be8392e3 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
@@ -110,6 +110,118 @@ class StableDiffusionPipeline:
         pil_images = [Image.fromarray(image) for image in images.numpy()]
         return pil_images
 
+    def produce_stencil_latents(
+        self,
+        latents,
+        text_embeddings,
+        guidance_scale,
+        total_timesteps,
+        dtype,
+        cpu_scheduling,
+        controlnet_hint=None,
+        controlnet=None,
+        controlnet_conditioning_scale: float = 1.0,
+        mask=None,
+        masked_image_latents=None,
+        return_all_latents=False,
+    ):
+        step_time_sum = 0
+        latent_history = [latents]
+        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
+        text_embeddings_numpy = text_embeddings.detach().numpy()
+        for i, t in tqdm(enumerate(total_timesteps)):
+            step_start_time = time.time()
+            timestep = torch.tensor([t]).to(dtype)
+            latent_model_input = self.scheduler.scale_model_input(latents, t)
+            if mask is not None and
masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)), + mask, + masked_image_latents, + ], + dim=1, + ).to(dtype) + if cpu_scheduling: + latent_model_input = latent_model_input.detach().numpy() + + if not torch.is_tensor(latent_model_input): + latent_model_input_1 = torch.from_numpy( + np.asarray(latent_model_input) + ).to(dtype) + else: + latent_model_input_1 = latent_model_input + control = controlnet( + "forward", + ( + latent_model_input_1, + timestep, + text_embeddings, + controlnet_hint, + ), + send_to_host=False, + ) + down_block_res_samples = control[0:12] + mid_block_res_sample = control[12:] + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample = ( + mid_block_res_sample[0] * controlnet_conditioning_scale + ) + timestep = timestep.detach().numpy() + # Profiling Unet. + profile_device = start_profiling(file_path="unet.rdc") + # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py. + noise_pred = self.unet( + "forward", + ( + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + down_block_res_samples[0], + down_block_res_samples[1], + down_block_res_samples[2], + down_block_res_samples[3], + down_block_res_samples[4], + down_block_res_samples[5], + down_block_res_samples[6], + down_block_res_samples[7], + down_block_res_samples[8], + down_block_res_samples[9], + down_block_res_samples[10], + down_block_res_samples[11], + mid_block_res_sample, + ), + send_to_host=False, + ) + end_profiling(profile_device) + + if cpu_scheduling: + noise_pred = torch.from_numpy(noise_pred.to_host()) + latents = self.scheduler.step( + noise_pred, t, latents + ).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents) + + latent_history.append(latents) + step_time = (time.time() - step_start_time) * 1000 + # self.log += ( + # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" + # ) + step_time_sum += step_time + + avg_step_time = step_time_sum / len(total_timesteps) + self.log += f"\nAverage step time: {avg_step_time}ms/it" + + if not return_all_latents: + return latents + all_latents = torch.cat(latent_history, dim=0) + return all_latents + def produce_img_latents( self, latents, @@ -205,6 +317,7 @@ class StableDiffusionPipeline: use_base_vae: bool, use_tuned: bool, low_cpu_mem_usage: bool = False, + use_stencil: bool = False, ): if import_mlir: mlir_import = SharkifyStableDiffusionModel( @@ -229,6 +342,11 @@ class StableDiffusionPipeline: return cls( vae_encode, vae, clip, get_tokenizer(), unet, scheduler ) + if cls.__name__ in ["StencilPipeline"]: + clip, unet, vae, controlnet = mlir_import() + return cls( + controlnet, vae, clip, get_tokenizer(), unet, scheduler + ) clip, unet, vae = mlir_import() return cls(vae, clip, get_tokenizer(), unet, scheduler) try: @@ -245,6 +363,12 @@ class StableDiffusionPipeline: get_unet(), scheduler, ) + if cls.__name__ == "StencilPipeline": + import sys + + sys.exit( + "StencilPipeline not supported with SharkTank currently." 
+ ) return cls( get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler ) @@ -272,5 +396,10 @@ class StableDiffusionPipeline: return cls( vae_encode, vae, clip, get_tokenizer(), unet, scheduler ) + if cls.__name__ == "StencilPipeline": + clip, unet, vae, controlnet = mlir_import() + return cls( + controlnet, vae, clip, get_tokenizer(), unet, scheduler + ) clip, unet, vae = mlir_import() return cls(vae, clip, get_tokenizer(), unet, scheduler) diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py index 6419abc3..0d5a1dd8 100644 --- a/apps/stable_diffusion/src/utils/__init__.py +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -11,6 +11,9 @@ from apps.stable_diffusion.src.utils.resources import ( ) from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation from apps.stable_diffusion.src.utils.stable_args import args +from apps.stable_diffusion.src.utils.stencils.stencil_utils import ( + controlnet_hint_conversion, +) from apps.stable_diffusion.src.utils.utils import ( get_shark_model, compile_through_fx, diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json index beb719df..9eda90ff 100644 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ b/apps/stable_diffusion/src/utils/resources/base_model.json @@ -85,6 +85,116 @@ "dtype": "f32" } }, + "stencil_adaptor": { + "latents": { + "shape": [ + "1*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "embedding": { + "shape": [ + "2*batch_size", + "max_len", + 768 + ], + "dtype": "f32" + }, + "controlnet_hint": { + "shape": [1, 3, 512, 512], + "dtype": "f32" + } + }, + "stencil_unet": { + "latents": { + "shape": [ + "1*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "embedding": { + "shape": [ + "2*batch_size", + "max_len", + 768 + ], + "dtype": "f32" + }, + "guidance_scale": { + "shape": 2, + "dtype": "f32" + }, + "control1": { + "shape": [2, 320, 64, 64], + "dtype": "f32" + }, + "control2": { + "shape": [2, 320, 64, 64], + "dtype": "f32" + }, + "control3": { + "shape": [2, 320, 64, 64], + "dtype": "f32" + }, + "control4": { + "shape": [2, 320, 32, 32], + "dtype": "f32" + }, + "control5": { + "shape": [2, 640, 32, 32], + "dtype": "f32" + }, + "control6": { + "shape": [2, 640, 32, 32], + "dtype": "f32" + }, + "control7": { + "shape": [2, 640, 16, 16], + "dtype": "f32" + }, + "control8": { + "shape": [2, 1280, 16, 16], + "dtype": "f32" + }, + "control9": { + "shape": [2, 1280, 16, 16], + "dtype": "f32" + }, + "control10": { + "shape": [2, 1280, 8, 8], + "dtype": "f32" + }, + "control11": { + "shape": [2, 1280, 8, 8], + "dtype": "f32" + }, + "control12": { + "shape": [2, 1280, 8, 8], + "dtype": "f32" + }, + "control13": { + "shape": [2, 1280, 8, 8], + "dtype": "f32" + } + }, "vae_encode": { "image" : { "shape" : [ @@ -223,4 +333,4 @@ } } } -} +} \ No newline at end of file diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index 0ecdeaa3..a2295f3a 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -272,6 +272,12 @@ p.add_argument( help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)", ) +p.add_argument( + "--use_stencil", + choices=["canny"], + help="Enable the stencil 
feature.", +) + ############################################################################## ### IREE - Vulkan supported flags ############################################################################## diff --git a/apps/stable_diffusion/src/utils/stencils/canny/__init__.py b/apps/stable_diffusion/src/utils/stencils/canny/__init__.py new file mode 100644 index 00000000..cb0da951 --- /dev/null +++ b/apps/stable_diffusion/src/utils/stencils/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py new file mode 100644 index 00000000..4456b7b7 --- /dev/null +++ b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py @@ -0,0 +1,155 @@ +import cv2 +import numpy as np +from PIL import Image +import torch +from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector + +stencil = {} + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize( + input_image, + (W, H), + interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA, + ) + return img + + +def controlnet_hint_shaping( + controlnet_hint, height, width, dtype, num_images_per_prompt=1 +): + channels = 3 + if isinstance(controlnet_hint, torch.Tensor): + # torch.Tensor: acceptble shape are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt) + shape_chw = (channels, height, width) + shape_bchw = (1, channels, height, width) + shape_nchw = (num_images_per_prompt, channels, height, width) + if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]: + controlnet_hint = controlnet_hint.to( + dtype=dtype, device=torch.device("cpu") + ) + if controlnet_hint.shape != shape_nchw: + controlnet_hint = controlnet_hint.repeat( + num_images_per_prompt, 1, 1, 1 + ) + return controlnet_hint + else: + raise ValueError( + f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width})," + + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, " + + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}" + ) + elif isinstance(controlnet_hint, np.ndarray): + # np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot) + # hwc is opencv compatible image format. Color channel must be BGR Format. 
+        if controlnet_hint.shape == (height, width):
+            controlnet_hint = np.repeat(
+                controlnet_hint[:, :, np.newaxis], channels, axis=2
+            )  # hw -> hwc(c==3)
+        shape_hwc = (height, width, channels)
+        shape_bhwc = (1, height, width, channels)
+        shape_nhwc = (num_images_per_prompt, height, width, channels)
+        if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
+            controlnet_hint = torch.from_numpy(controlnet_hint.copy())
+            controlnet_hint = controlnet_hint.to(
+                dtype=dtype, device=torch.device("cpu")
+            )
+            controlnet_hint /= 255.0
+            if controlnet_hint.shape != shape_nhwc:
+                controlnet_hint = controlnet_hint.repeat(
+                    num_images_per_prompt, 1, 1, 1
+                )
+            controlnet_hint = controlnet_hint.permute(
+                0, 3, 1, 2
+            )  # b h w c -> b c h w
+            return controlnet_hint
+        else:
+            raise ValueError(
+                f"Acceptable shapes of `stencil` are any of ({height}, {width}), "
+                + f"({height}, {width}, {channels}), "
+                + f"(1, {height}, {width}, {channels}) or "
+                + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
+            )
+    elif isinstance(controlnet_hint, Image.Image):
+        if controlnet_hint.size == (width, height):
+            controlnet_hint = controlnet_hint.convert(
+                "RGB"
+            )  # make sure 3 channel RGB format
+            controlnet_hint = np.array(controlnet_hint)  # to numpy
+            controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
+            return controlnet_hint_shaping(
+                controlnet_hint, height, width, dtype, num_images_per_prompt
+            )
+        else:
+            raise ValueError(
+                f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
+            )
+    else:
+        raise ValueError(
+            f"Acceptable types of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
+        )
+
+
+def controlnet_hint_conversion(
+    image, use_stencil, height, width, dtype, num_images_per_prompt=1
+):
+    controlnet_hint = None
+    match use_stencil:
+        case "canny":
+            print("Detecting edges with canny")
+            controlnet_hint = hint_canny(image, width)
+        case _:
+            return None
+    controlnet_hint = controlnet_hint_shaping(
+        controlnet_hint, height, width, dtype, num_images_per_prompt
+    )
+    return controlnet_hint
+
+
+# Stencil 1. Canny
+def hint_canny(
+    image: Image.Image,
+    width=512,
+    height=512,
+    low_threshold=100,
+    high_threshold=200,
+):
+    with torch.no_grad():
+        input_image = np.array(image)
+        image_resolution = width
+
+        img = resize_image(HWC3(input_image), image_resolution)
+
+        if "canny" not in stencil:
+            stencil["canny"] = CannyDetector()
+        detected_map = stencil["canny"](img, low_threshold, high_threshold)
+        detected_map = HWC3(detected_map)
+        return detected_map
diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py
index 22111c7f..0256dcea 100644
--- a/apps/stable_diffusion/src/utils/utils.py
+++ b/apps/stable_diffusion/src/utils/utils.py
@@ -450,6 +450,7 @@ def preprocessCKPT(custom_weights):
 
 def load_vmfb(vmfb_path, model, precision):
     model = "vae" if "base_vae" in model or "vae_encode" in model else model
+    model = "unet" if "stencil" in model else model
     precision = "fp32" if "clip" in model else precision
     extra_args = get_opt_flags(model, precision)
     shark_module = SharkInference(mlir_module=None, device=args.device)
@@ -459,32 +460,28 @@ def load_vmfb(vmfb_path, model, precision):
 
 # This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
 # are present; deletes them otherwise.
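+# The expected set of vmfbs is derived from `extended_model_name`, whose
+# entries depend on the `mask_to_fetch` chosen in model_wrappers.py.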
-def fetch_or_delete_vmfbs( - extended_model_name, need_vae_encode, precision="fp32" -): +def fetch_or_delete_vmfbs(extended_model_name, precision="fp32"): vmfb_path = [ get_vmfb_path_name(extended_model_name[model]) for model in extended_model_name ] + number_of_vmfbs = len(vmfb_path) vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path] all_vmfb_present = True - compiled_models = [] - for i in range(3): + compiled_models = [None] * number_of_vmfbs + + for i in range(number_of_vmfbs): all_vmfb_present = all_vmfb_present and vmfb_present[i] - compiled_models.append(None) - if need_vae_encode: - all_vmfb_present = all_vmfb_present and vmfb_present[3] - compiled_models.append(None) # We need to delete vmfbs only if some of the models were compiled. if not all_vmfb_present: - for i in range(len(compiled_models)): + for i in range(number_of_vmfbs): if vmfb_present[i]: os.remove(vmfb_path[i]) print("Deleted: ", vmfb_path[i]) else: model_name = [model for model in extended_model_name.keys()] - for i in range(len(compiled_models)): + for i in range(number_of_vmfbs): compiled_models[i] = load_vmfb( vmfb_path[i], model_name[i], precision ) diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 2d1af250..c72ba993 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -79,6 +79,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: height=300 ) + with gr.Accordion(label="Stencil Options", open=False): + with gr.Row(): + use_stencil = gr.Dropdown( + label="Stencil model", + value="None", + choices=["None", "canny"], + ) with gr.Accordion(label="Advanced Options", open=False): with gr.Row(): scheduler = gr.Dropdown( @@ -116,7 +123,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: "fp16", "fp32", ], - visible=False, + visible=True, ) max_length = gr.Radio( label="Max Length", @@ -221,6 +228,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: precision, device, max_length, + use_stencil, save_metadata_to_json, save_metadata_to_png, ],
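
Below, a minimal sketch of exercising the new hint-preparation path on its own, assuming this patch is applied and `torch`, `opencv-python`, and `Pillow` are installed; "input.png" stands in for any conditioning image:

    import torch
    from PIL import Image

    from apps.stable_diffusion.src.utils import controlnet_hint_conversion

    # Build the canny hint the same way StencilPipeline.generate_images does:
    # detect edges, lift the map to 3 channels, scale to [0, 1], and reshape
    # to NCHW for the compiled ControlNet.
    image = Image.open("input.png").convert("RGB").resize((512, 512))
    hint = controlnet_hint_conversion(
        image, "canny", 512, 512, dtype=torch.float32
    )
    print(hint.shape)  # torch.Size([1, 3, 512, 512]), as in base_model.json

The resulting tensor is what `generate_images` forwards to `produce_stencil_latents` as `controlnet_hint`; when no stencil is selected, the helper returns None and img2img.py constructs the plain Image2ImagePipeline instead.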