[WIP] Add ControlNet to SD pipeline

-- This commit adds ControlNet support to the Stable Diffusion (SD) pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Author: Abhishek Varma
Date: 2023-02-22 15:00:26 +00:00
parent 9c5415b598
commit c42ef3cd60
5 changed files with 276 additions and 12 deletions

apps/README.md (new file)

@@ -0,0 +1,11 @@
# [WIP] ControlNet integration:
* Since this depends on `diffusers` and the upstream support is still a [WIP draft PR](https://github.com/huggingface/diffusers/pull/2407), you need to clone [takuma104/diffusers](https://github.com/takuma104/diffusers/), check out the `controlnet` [branch](https://github.com/takuma104/diffusers/tree/controlnet), and build `diffusers` from source as described [here](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
* NOTE: Run `pip uninstall diffusers` before building it from source in the previous step.
* Currently, `ControlNet` (with the Canny edge condition) is included in our SHARK Stable Diffusion pipeline.
* To test it, first download [control_sd15_canny.pth](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth) and convert it with [convert_controlnet_to_diffusers.py](https://github.com/huggingface/diffusers/blob/faf1cfbe826c88366524e92fa27b2104effdb8c4/scripts/convert_controlnet_to_diffusers.py) (a quick sanity check for the converted checkpoint is sketched after this list).
* You then need to modify [this](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py#L62) line to point to the converted `diffusers` checkpoint directory.
* Download the [sample image](https://drive.google.com/file/d/13ncuKjTO-reK-nBcLM1Lzuli5WPPB38v/view?usp=sharing) and provide its path [here](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py#L152).
* Finally, run:
```shell
python apps/stable_diffusion/scripts/txt2img.py --precision=fp32 --device=cuda --prompt="bird" --max_length=64 --import_mlir --enable_stack_trace --no-use_tuned --hf_model_id="runwayml/stable-diffusion-v1-5" --scheduler="DDIM"
```
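* A quick way to sanity-check the converted checkpoint directory from the conversion step above is to load it the same way the pipeline in this commit does. A minimal sketch, with a placeholder path (point it at the output directory you used during conversion):
```python
# Minimal sanity check for the converted ControlNet checkpoint. The path is a
# placeholder; the load mirrors what pipeline_shark_stable_diffusion_utils.py
# does in this commit.
from diffusers import UNet2DConditionModel

controlnet = UNet2DConditionModel.from_pretrained(
    "/path/to/canny_weight", subfolder="controlnet"
)
print(sum(p.numel() for p in controlnet.parameters()), "parameters loaded")
```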


@@ -227,13 +227,17 @@ class SharkifyStableDiffusionModel:
         self.in_channels = self.unet.in_channels
         self.train(False)
 
+    # TODO: Instead of flattening the `control` try to use the list.
     def forward(
-        self, latent, timestep, text_embedding, guidance_scale
+        self, latent, timestep, text_embedding, guidance_scale,
+        control1, control2, control3, control4, control5, control6, control7,
+        control8, control9, control10, control11, control12, control13
     ):
+        control = [control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12, control13]
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         latents = torch.cat([latent] * 2)
         unet_out = self.unet.forward(
-            latents, timestep, text_embedding, return_dict=False
+            latents, timestep, text_embedding, control=control, return_dict=False
         )[0]
         noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (
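
For context on the thirteen flattened `control*` arguments above: ControlNet for SD 1.5 emits twelve down-block residuals plus one mid-block residual, and the wrapper flattens that list, presumably so each tensor gets a fixed argument name and static shape when the module is compiled. The shapes below are copied from the UNet input-spec JSON added later in this commit, assuming a 512x512 image (64x64 latents); treating the last entry as the mid-block residual follows ControlNet's usual ordering.
```python
# Control tensor shapes for a 512x512 image (64x64 latent), taken from the
# input spec in this commit: control1..control12 are the ControlNet down-block
# residuals, control13 is the mid-block residual.
control_shapes = [
    (1, 320, 64, 64), (1, 320, 64, 64), (1, 320, 64, 64),
    (1, 320, 32, 32), (1, 640, 32, 32), (1, 640, 32, 32),
    (1, 640, 16, 16), (1, 1280, 16, 16), (1, 1280, 16, 16),
    (1, 1280, 8, 8), (1, 1280, 8, 8), (1, 1280, 8, 8),
    (1, 1280, 8, 8),  # mid-block residual
]
assert len(control_shapes) == 13
```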


@@ -19,6 +19,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
    StableDiffusionPipeline,
)
import cv2
from PIL import Image


class Text2ImagePipeline(StableDiffusionPipeline):
    def __init__(
@@ -65,6 +68,66 @@ class Text2ImagePipeline(StableDiffusionPipeline):
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # TODO: Move it to a separate ControlNet based specific model functions.
    def HWC3(self, x):
        assert x.dtype == np.uint8
        if x.ndim == 2:
            x = x[:, :, None]
        assert x.ndim == 3
        H, W, C = x.shape
        assert C == 1 or C == 3 or C == 4
        if C == 3:
            return x
        if C == 1:
            return np.concatenate([x, x, x], axis=2)
        if C == 4:
            color = x[:, :, 0:3].astype(np.float32)
            alpha = x[:, :, 3:4].astype(np.float32) / 255.0
            y = color * alpha + 255.0 * (1.0 - alpha)
            y = y.clip(0, 255).astype(np.uint8)
            return y

    # TODO: Move it to a separate ControlNet based specific model functions.
    def resize_image(self, input_image, resolution):
        H, W, C = input_image.shape
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        H *= k
        W *= k
        H = int(np.round(H / 64.0)) * 64
        W = int(np.round(W / 64.0)) * 64
        img = cv2.resize(
            input_image,
            (W, H),
            interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
        )
        return img

    # TODO: Move it to a separate ControlNet based specific model functions.
    def hint_canny(
        self,
        image: Image.Image,
        width=512,
        height=512,
        low_threshold=100,
        high_threshold=200,
    ):
        with torch.no_grad():
            input_image = np.array(image)
            image_resolution = width
            img = self.resize_image(self.HWC3(input_image), image_resolution)

            class CannyDetector:
                def __call__(self, img, low_threshold, high_threshold):
                    return cv2.Canny(img, low_threshold, high_threshold)

            canny = CannyDetector()
            detected_map = canny(img, low_threshold, high_threshold)
            detected_map = self.HWC3(detected_map)
            return detected_map

    def generate_images(
        self,
        prompts,
@@ -80,6 +143,18 @@ class Text2ImagePipeline(StableDiffusionPipeline):
        use_base_vae,
        cpu_scheduling,
    ):
        # Control Embedding check & conversion
        # TODO: 1. `controlnet_hint`.
        #       2. Change `num_images_per_prompt`.
        #       3. Supply `controlnet_img`.
        from PIL import Image

        controlnet_img = Image.open("/home/abhishek/cn_1.png")
        controlnet_hint = self.hint_canny(controlnet_img)
        controlnet_hint = self.controlnet_hint_conversion(
            controlnet_hint, height, width, num_images_per_prompt=1
        )

        # prompts and negative prompts must be a list.
        if isinstance(prompts, str):
            prompts = [prompts]
@@ -122,6 +197,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            controlnet_hint=controlnet_hint,
        )

        # Img latents -> PIL images
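
Taken together, the hint preparation used in `generate_images` above boils down to the following standalone sketch. `pipeline` stands for an already-constructed `Text2ImagePipeline`, the image path is a placeholder, and a square input image is assumed so the resized edge map comes out at 512x512:
```python
# Sketch of the Canny hint preparation path (assumes `pipeline` is a constructed
# Text2ImagePipeline and the input image is square).
from PIL import Image

controlnet_img = Image.open("/path/to/input.png")

# HWC3 + resize (rounded to multiples of 64) + cv2.Canny -> HxWx3 uint8 edge map.
hint = pipeline.hint_canny(controlnet_img, width=512, height=512)

# -> (1, 3, 512, 512) float32 tensor in [0, 1], ready for the ControlNet call.
hint = pipeline.controlnet_hint_conversion(hint, height=512, width=512, num_images_per_prompt=1)
```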


@@ -55,6 +55,12 @@ class StableDiffusionPipeline:
        self.scheduler = scheduler
        # TODO: Implement using logging python utility.
        self.log = ""
        # TODO: Make this dynamic like other models which'll be passed to StableDiffusionPipeline.
        from diffusers import UNet2DConditionModel

        self.controlnet = UNet2DConditionModel.from_pretrained(
            "/home/abhishek/weights/canny_weight", subfolder="controlnet"
        )

    def encode_prompts(self, prompts, neg_prompts, max_length):
        # Tokenize text and get embeddings
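
The TODO above asks for the ControlNet weights to be supplied like the other models rather than via a hard-coded path; a minimal sketch of that direction could look as follows (the helper name and parameters are assumptions, not part of this commit):
```python
# Hypothetical helper implementing the TODO above: load the ControlNet weights
# from a caller-supplied path instead of a hard-coded one. Names are assumptions.
from diffusers import UNet2DConditionModel


def load_controlnet(controlnet_path=None, subfolder="controlnet"):
    """Return the ControlNet module, or None when no path is given."""
    if controlnet_path is None:
        return None
    return UNet2DConditionModel.from_pretrained(controlnet_path, subfolder=subfolder)
```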
@@ -116,6 +122,7 @@
        total_timesteps,
        dtype,
        cpu_scheduling,
        controlnet_hint=None,
        mask=None,
        masked_image_latents=None,
        return_all_latents=False,
@@ -126,7 +133,7 @@
         text_embeddings_numpy = text_embeddings.detach().numpy()
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
-            timestep = torch.tensor([t]).to(dtype).detach().numpy()
+            timestep = torch.tensor([t]).to(dtype)
             latent_model_input = self.scheduler.scale_model_input(latents, t)
             if mask is not None and masked_image_latents is not None:
                 latent_model_input = torch.cat(
@@ -142,16 +149,56 @@
             # Profiling Unet.
             profile_device = start_profiling(file_path="unet.rdc")
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
-                    timestep,
-                    text_embeddings_numpy,
-                    guidance_scale,
-                ),
-                send_to_host=False,
-            )
+            if controlnet_hint is not None:
+                if not torch.is_tensor(latent_model_input):
+                    latent_model_input_1 = torch.from_numpy(
+                        np.asarray(latent_model_input)
+                    ).to(dtype)
+                else:
+                    latent_model_input_1 = latent_model_input
+                control = self.controlnet(
+                    latent_model_input_1,
+                    timestep,
+                    encoder_hidden_states=text_embeddings,
+                    controlnet_hint=controlnet_hint,
+                )
+                timestep = timestep.detach().numpy()
+                # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                        control[0],
+                        control[1],
+                        control[2],
+                        control[3],
+                        control[4],
+                        control[5],
+                        control[6],
+                        control[7],
+                        control[8],
+                        control[9],
+                        control[10],
+                        control[11],
+                        control[12],
+                    ),
+                    send_to_host=False,
+                )
+            else:
+                timestep = timestep.detach().numpy()
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                    ),
+                    send_to_host=False,
+                )
             end_profiling(profile_device)
             if cpu_scheduling:
@@ -177,6 +224,80 @@
            all_latents = torch.cat(latent_history, dim=0)
            return all_latents

    def controlnet_hint_conversion(
        self, controlnet_hint, height, width, num_images_per_prompt=1
    ):
        channels = 3
        if isinstance(controlnet_hint, torch.Tensor):
            # torch.Tensor: acceptable shapes are chw, bchw(b==1) or bchw(b==num_images_per_prompt)
            shape_chw = (channels, height, width)
            shape_bchw = (1, channels, height, width)
            shape_nchw = (num_images_per_prompt, channels, height, width)
            if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
                controlnet_hint = controlnet_hint.to(
                    dtype=torch.float32, device=torch.device("cpu")
                )
                if controlnet_hint.shape != shape_nchw:
                    controlnet_hint = controlnet_hint.repeat(
                        num_images_per_prompt, 1, 1, 1
                    )
                return controlnet_hint
            else:
                raise ValueError(
                    f"Acceptable shape of `controlnet_hint` is any of ({channels}, {height}, {width}),"
                    + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                    + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
                )
        elif isinstance(controlnet_hint, np.ndarray):
            # np.ndarray: acceptable shapes are hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
            # hwc is the OpenCV-compatible image format; color channels must be in BGR order.
            if controlnet_hint.shape == (height, width):
                controlnet_hint = np.repeat(
                    controlnet_hint[:, :, np.newaxis], channels, axis=2
                )  # hw -> hwc(c==3)
            shape_hwc = (height, width, channels)
            shape_bhwc = (1, height, width, channels)
            shape_nhwc = (num_images_per_prompt, height, width, channels)
            if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
                controlnet_hint = torch.from_numpy(controlnet_hint.copy())
                controlnet_hint = controlnet_hint.to(
                    dtype=torch.float32, device=torch.device("cpu")
                )
                controlnet_hint /= 255.0
                if controlnet_hint.shape != shape_nhwc:
                    controlnet_hint = controlnet_hint.repeat(
                        num_images_per_prompt, 1, 1, 1
                    )
                controlnet_hint = controlnet_hint.permute(
                    0, 3, 1, 2
                )  # b h w c -> b c h w
                return controlnet_hint
            else:
                raise ValueError(
                    f"Acceptable shape of `controlnet_hint` is any of ({height}, {width}), "
                    + f"({height}, {width}, {channels}), "
                    + f"(1, {height}, {width}, {channels}) or "
                    + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
                )
        elif isinstance(controlnet_hint, Image.Image):
            if controlnet_hint.size == (width, height):
                controlnet_hint = controlnet_hint.convert(
                    "RGB"
                )  # make sure 3 channel RGB format
                controlnet_hint = np.array(controlnet_hint)  # to numpy
                controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
                return self.controlnet_hint_conversion(
                    controlnet_hint, height, width, num_images_per_prompt
                )
            else:
                raise ValueError(
                    f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
                )
        else:
            raise ValueError(
                f"Acceptable type of `controlnet_hint` is any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
            )

    @classmethod
    def from_pretrained(
        cls,
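
As a quick reference for the inputs `controlnet_hint_conversion` above accepts, the following sketch exercises all three branches (`pipe` stands for a constructed `StableDiffusionPipeline`; the values are dummies):
```python
# Sketch of the three accepted controlnet_hint types; each call returns a
# (1, 3, 512, 512) float32 tensor. `pipe` is assumed to be a constructed
# StableDiffusionPipeline.
import numpy as np
import torch
from PIL import Image

h, w = 512, 512
hints = [
    torch.zeros(3, h, w),                   # chw tensor
    np.zeros((h, w, 3), dtype=np.uint8),    # hwc uint8 array (BGR)
    Image.new("RGB", (w, h)),               # PIL image sized (width, height)
]
for hint in hints:
    out = pipe.controlnet_hint_conversion(hint, h, w, num_images_per_prompt=1)
    print(type(hint).__name__, tuple(out.shape))  # -> (1, 3, 512, 512)
```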


@@ -83,6 +83,58 @@
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control2": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control3": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control4": {
"shape": [1, 320, 32, 32],
"dtype": "f32"
},
"control5": {
"shape": [1, 640, 32, 32],
"dtype": "f32"
},
"control6": {
"shape": [1, 640, 32, 32],
"dtype": "f32"
},
"control7": {
"shape": [1, 640, 16, 16],
"dtype": "f32"
},
"control8": {
"shape": [1, 1280, 16, 16],
"dtype": "f32"
},
"control9": {
"shape": [1, 1280, 16, 16],
"dtype": "f32"
},
"control10": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control11": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control12": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control13": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
}
},
"vae_encode": {