Compare commits

...

4 Commits

Author SHA1 Message Date
Phaneesh Barwaria
604c0f81b9 Update model_wrappers.py
remove exception
2023-02-24 19:55:48 +05:30
PhaneeshB
3366a62dbc add shark models for stencilSD 2023-02-24 19:38:42 +05:30
Abhishek Varma
dee41453b5 [SD] Add ControlNet to img2img + fix bug for img2img scheduler
-- This commit adds ControlNet execution to img2img.
-- It restructures the addition of ControlNet variants.
-- It also fixes scheduler selecting bug for img2img pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-24 19:07:19 +05:30
Abhishek Varma
c42ef3cd60 [WIP] Add ControlNet to SD pipeline
-- This commit adds ControlNet to SD pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-22 15:06:27 +00:00
11 changed files with 605 additions and 30 deletions

11
apps/README.md Normal file
View File

@@ -0,0 +1,11 @@
# [WIP] ControlNet integration:
* ControlNet support in `diffusers` currently exists only as a [WIP draft PR](https://github.com/huggingface/diffusers/pull/2407). Until it is merged, clone [takuma104/diffusers](https://github.com/takuma104/diffusers/), check out the `controlnet` [branch](https://github.com/takuma104/diffusers/tree/controlnet), and build `diffusers` as described [here](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
* NOTE: Run `pip uninstall diffusers` before building it from scratch in the previous step.
* Currently we've included `ControlNet` (Canny feature) as part of our SharkStableDiffusion pipeline.
* To test it you'd first need to download [control_sd15_canny.pth](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth) and pass it through [convert_controlnet_to_diffusers.py](https://github.com/huggingface/diffusers/blob/faf1cfbe826c88366524e92fa27b2104effdb8c4/scripts/convert_controlnet_to_diffusers.py).
* You then need to modify [this](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py#L62) line to point to the extracted `diffusers` checkpoint directory.
* Download [sample image](https://drive.google.com/file/d/13ncuKjTO-reK-nBcLM1Lzuli5WPPB38v/view?usp=sharing) and provide the path [here](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py#L152).
* Finally, run:
```shell
python apps/stable_diffusion/scripts/txt2img.py --precision=fp32 --device=cuda --prompt="bird" --max_length=64 --import_mlir --enable_stack_trace --no-use_tuned --hf_model_id="runwayml/stable-diffusion-v1-5" --scheduler="DDIM"
```

View File

@@ -195,10 +195,10 @@ if __name__ == "__main__":
args.import_mlir = True
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
if args.scheduler != "PNDM":
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
@@ -208,11 +208,13 @@ if __name__ == "__main__":
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path).convert("RGB")
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -229,6 +231,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
)
start_time = time.time()
@@ -247,6 +250,7 @@ if __name__ == "__main__":
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -119,7 +119,7 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "vae", "vae_encode"]
sub_model_list = ["clip", "unet", "vae", "vae_encode", "controlnet", "controlled_unet"]
for model in sub_model_list:
sub_model = model
model_config = self.model_name
@@ -215,6 +215,109 @@ class SharkifyStableDiffusionModel:
)
return shark_vae
def get_controlled_unet(self):
    """Build and compile the UNet variant that consumes ControlNet outputs.

    The wrapped UNet is loaded from the hard-coded
    "takuma104/control_sd15_canny" checkpoint (subfolder "unet"). Its
    ``forward`` receives the 13 control tensors as separate positional
    arguments, regroups them into a list for the underlying UNet, and
    combines the two halves of the batched prediction using
    ``guidance_scale``.

    Returns:
        The object returned by ``compile_through_fx`` for this module.
    """

    class ControlledUnetModel(torch.nn.Module):
        def __init__(
            self, model_id=self.model_id, low_cpu_mem_usage=False
        ):
            super().__init__()
            # TODO: ADD with model ID
            checkpoint = "takuma104/control_sd15_canny"
            self.unet = UNet2DConditionModel.from_pretrained(
                checkpoint,
                subfolder="unet",
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
            self.in_channels = self.unet.in_channels
            self.train(False)

        def forward(
            self,
            latent,
            timestep,
            text_embedding,
            guidance_scale,
            control1,
            control2,
            control3,
            control4,
            control5,
            control6,
            control7,
            control8,
            control9,
            control10,
            control11,
            control12,
            control13,
        ):
            # The control tensors arrive flattened; regroup them into the
            # list form passed to the underlying UNet.
            control = [
                control1,
                control2,
                control3,
                control4,
                control5,
                control6,
                control7,
                control8,
                control9,
                control10,
                control11,
                control12,
                control13,
            ]
            # Expand the latents if we are doing classifier-free guidance
            # to avoid doing two forward passes.
            doubled_latents = torch.cat([latent] * 2)
            prediction = self.unet.forward(
                doubled_latents,
                timestep,
                text_embedding,
                return_dict=False,
                control=control,
            )[0]
            uncond_pred, text_pred = prediction.chunk(2)
            return uncond_pred + guidance_scale * (text_pred - uncond_pred)

    wrapped_unet = ControlledUnetModel(
        low_cpu_mem_usage=self.low_cpu_mem_usage
    )
    use_f16 = self.precision == "fp16"
    compile_inputs = tuple(self.inputs["controlled_unet"])
    # One flag per positional input of `forward`; only the fourth entry
    # (guidance_scale) is False.
    f16_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
    return compile_through_fx(
        wrapped_unet,
        compile_inputs,
        model_name=self.model_name["controlled_unet"],
        is_f16=use_f16,
        f16_input_mask=f16_mask,
        use_tuned=self.use_tuned,
        extra_args=get_opt_flags("unet", precision=self.precision),
    )
def get_control_net(self):
    """Build and compile the ControlNet module.

    The network is loaded from the hard-coded
    "takuma104/control_sd15_canny" checkpoint (subfolder "controlnet").
    Its ``forward`` duplicates the latents along the batch axis and
    returns the network outputs as a tuple.

    Returns:
        The object returned by ``compile_through_fx`` for this module.
    """

    class ControlNetModel(torch.nn.Module):
        def __init__(
            self, model_id=self.model_id, low_cpu_mem_usage=False
        ):
            super().__init__()
            # TODO: ADD with model ID
            checkpoint = "takuma104/control_sd15_canny"
            self.unet = UNet2DConditionModel.from_pretrained(
                checkpoint,
                subfolder="controlnet",
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
            self.in_channels = self.unet.in_channels
            self.train(False)

        def forward(
            self,
            latent,
            timestep,
            text_embedding,
            controlnet_hint,
        ):
            # Expand the latents if we are doing classifier-free guidance
            # to avoid doing two forward passes.
            # TODO: guidance NOT NEEDED change in `get_input_info` later
            # Needs to be the same as the controlled UNet latents.
            doubled_latents = torch.cat([latent] * 2)
            outputs = self.unet.forward(
                doubled_latents,
                timestep,
                text_embedding,
                return_dict=False,
                controlnet_hint=controlnet_hint,
            )
            return tuple(outputs)

    control_net = ControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
    use_f16 = self.precision == "fp16"
    compile_inputs = tuple(self.inputs["controlnet"])
    # One flag per positional input of `forward`.
    f16_mask = [True, True, True, True]
    return compile_through_fx(
        control_net,
        compile_inputs,
        model_name=self.model_name["controlnet"],
        is_f16=use_f16,
        f16_input_mask=f16_mask,
        use_tuned=self.use_tuned,
        extra_args=get_opt_flags("unet", precision=self.precision),
    )
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
@@ -227,13 +330,17 @@ class SharkifyStableDiffusionModel:
self.in_channels = self.unet.in_channels
self.train(False)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale
self, latent, timestep, text_embedding, guidance_scale,
control1, control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13
):
control = [control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12, control13]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
latents, timestep, text_embedding, control=control, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
@@ -309,6 +416,8 @@ class SharkifyStableDiffusionModel:
self.height,
self.batch_size,
)
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")

View File

@@ -19,6 +19,7 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
class Image2ImagePipeline(StableDiffusionPipeline):
@@ -42,6 +43,31 @@ class Image2ImagePipeline(StableDiffusionPipeline):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Sample random starting latents scaled by the scheduler's initial sigma.

    Args:
        batch_size: Number of latent samples to draw.
        height: Image height in pixels; the latent height is ``height // 8``.
        width: Image width in pixels; the latent width is ``width // 8``.
        generator: ``torch.Generator`` passed to ``torch.randn``.
        num_inference_steps: Step count passed to ``scheduler.set_timesteps``.
        dtype: Dtype the sampled noise is cast to.

    Returns:
        A tensor of shape ``(batch_size, 4, height // 8, width // 8)``.
    """
    latent_shape = (batch_size, 4, height // 8, width // 8)
    # The noise is always drawn in float32 and only then cast to `dtype`.
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    )
    noise = noise.to(dtype)

    scheduler = self.scheduler
    scheduler.set_timesteps(num_inference_steps)
    scheduler.is_scale_input_called = True
    # `init_noise_sigma` is read only after the timesteps have been set.
    return noise * scheduler.init_noise_sigma
def prepare_image_latents(
self,
image,
@@ -110,7 +136,13 @@ class Image2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
use_stencil,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, num_images_per_prompt=1
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -134,26 +166,40 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)
# Prepare initial latent.
init_latents = None
final_timesteps = None
if controlnet_hint is not None:
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
else:
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
)
# Img latents -> PIL images

View File

@@ -19,6 +19,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
import cv2
from PIL import Image
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(

View File

@@ -55,6 +55,12 @@ class StableDiffusionPipeline:
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
# TODO: Make this dynamic like other models which'll be passed to StableDiffusionPipeline.
from diffusers import UNet2DConditionModel
self.controlnet = UNet2DConditionModel.from_pretrained(
"/home/abhishek/weights/canny_weight", subfolder="controlnet"
)
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -116,6 +122,7 @@ class StableDiffusionPipeline:
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
mask=None,
masked_image_latents=None,
return_all_latents=False,
@@ -126,7 +133,7 @@ class StableDiffusionPipeline:
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
timestep = torch.tensor([t]).to(dtype)
latent_model_input = self.scheduler.scale_model_input(latents, t)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
@@ -142,16 +149,56 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
if controlnet_hint is not None:
if not torch.is_tensor(latent_model_input):
latent_model_input_1 = torch.from_numpy(
np.asarray(latent_model_input)
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
latent_model_input_1,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
encoder_hidden_states=text_embeddings,
controlnet_hint=controlnet_hint,
)
timestep = timestep.detach().numpy()
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
timestep = timestep.detach().numpy()
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -177,6 +224,80 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def controlnet_hint_conversion(
    self, controlnet_hint, height, width, num_images_per_prompt=1
):
    """Convert a ControlNet hint into a float32 CPU tensor in NCHW layout.

    Args:
        controlnet_hint: The hint, as one of:
            * ``torch.Tensor`` shaped (3, H, W), (1, 3, H, W) or
              (N, 3, H, W);
            * ``np.ndarray`` shaped (H, W), (H, W, 3), (1, H, W, 3) or
              (N, H, W, 3); values are divided by 255;
            * ``PIL.Image.Image`` whose ``size`` equals (width, height).
          Here H is ``height``, W is ``width`` and N is
          ``num_images_per_prompt``.
        height: Expected hint height.
        width: Expected hint width.
        num_images_per_prompt: Batch size the hint is repeated to when it
            is not already batched to that size.

    Returns:
        A float32 CPU tensor of shape (N, 3, H, W).

    Raises:
        ValueError: If the hint has an unsupported type, shape or size.
    """
    channels = 3
    if isinstance(controlnet_hint, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b == 1) or
        # bchw (b == num_images_per_prompt).
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint.shape not in [shape_chw, shape_bchw, shape_nchw]:
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
            )
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        if controlnet_hint.shape != shape_nchw:
            # `repeat` with four factors also promotes a chw tensor to 4-D.
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        return controlnet_hint

    if isinstance(controlnet_hint, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b == 1) or
        # bhwc (b == num_images_per_prompt).
        # hwc is the OpenCV-compatible layout; colour channels must be BGR.
        if controlnet_hint.shape == (height, width):
            controlnet_hint = np.repeat(
                controlnet_hint[:, :, np.newaxis], channels, axis=2
            )  # hw -> hwc (c == 3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint.shape not in [shape_hwc, shape_bhwc, shape_nhwc]:
            # The message lists the shapes actually accepted above
            # (channels-last), not the tensor (channels-first) ones.
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({height}, {width}), "
                + f"({height}, {width}, {channels}), "
                + f"(1, {height}, {width}, {channels}) or "
                + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
            )
        # Copy first: the array may be a negatively-strided view, which
        # `torch.from_numpy` cannot wrap.
        controlnet_hint = torch.from_numpy(controlnet_hint.copy())
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        controlnet_hint /= 255.0
        if controlnet_hint.shape != shape_nhwc:
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        return controlnet_hint.permute(0, 3, 1, 2)  # b h w c -> b c h w

    if isinstance(controlnet_hint, Image.Image):
        if controlnet_hint.size != (width, height):
            raise ValueError(
                f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
            )
        # Make sure the image is 3-channel RGB, then flip to BGR and
        # reuse the ndarray branch.
        controlnet_hint = np.array(controlnet_hint.convert("RGB"))
        controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
        return self.controlnet_hint_conversion(
            controlnet_hint, height, width, num_images_per_prompt
        )

    raise ValueError(
        f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
    )
@classmethod
def from_pretrained(
cls,
@@ -202,6 +323,7 @@ class StableDiffusionPipeline:
use_base_vae: bool,
use_tuned: bool,
low_cpu_mem_usage: bool = False,
use_stencil: bool = False,
):
if import_mlir:
mlir_import = SharkifyStableDiffusionModel(

View File

@@ -11,6 +11,9 @@ from apps.stable_diffusion.src.utils.resources import (
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
)
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,

View File

@@ -85,6 +85,116 @@
"dtype": "f32"
}
},
"controlnet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, 512, 512],
"dtype": "f32"
}
},
"controlled_unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
},
"control1": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control2": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control3": {
"shape": [1, 320, 64, 64],
"dtype": "f32"
},
"control4": {
"shape": [1, 320, 32, 32],
"dtype": "f32"
},
"control5": {
"shape": [1, 640, 32, 32],
"dtype": "f32"
},
"control6": {
"shape": [1, 640, 32, 32],
"dtype": "f32"
},
"control7": {
"shape": [1, 640, 16, 16],
"dtype": "f32"
},
"control8": {
"shape": [1, 1280, 16, 16],
"dtype": "f32"
},
"control9": {
"shape": [1, 1280, 16, 16],
"dtype": "f32"
},
"control10": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control11": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control12": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
},
"control13": {
"shape": [1, 1280, 8, 8],
"dtype": "f32"
}
},
"vae_encode": {
"image" : {
"shape" : [

View File

@@ -200,6 +200,12 @@ p.add_argument(
help="Use the accelerate package to reduce cpu memory consumption",
)
p.add_argument(
"--use_stencil",
choices=["canny"],
help="Enable the stencil feature.",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

View File

@@ -0,0 +1,6 @@
import cv2
class CannyDetector:
    """Callable wrapper around ``cv2.Canny``."""

    def __call__(self, img, low_threshold, high_threshold):
        """Return ``cv2.Canny(img, low_threshold, high_threshold)``."""
        edge_map = cv2.Canny(img, low_threshold, high_threshold)
        return edge_map

View File

@@ -0,0 +1,155 @@
import cv2
import numpy as np
from PIL import Image
import torch
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
# Module-level cache of stencil detector instances, keyed by stencil name;
# `hint_canny` stores its `CannyDetector` here on first use.
stencil = {}
def HWC3(x):
    """Return a uint8 image as a 3-channel height x width x channel array.

    Args:
        x: ``np.uint8`` array, either 2-D (H, W) or 3-D (H, W, C) with
            C in {1, 3, 4}.

    Returns:
        An (H, W, 3) uint8 array. Single-channel input is replicated
        across three channels, 3-channel input is returned as is, and
        4-channel input is alpha-blended over a white background.
    """
    assert x.dtype == np.uint8
    image = x[:, :, None] if x.ndim == 2 else x
    assert image.ndim == 3
    channel_count = image.shape[2]
    assert channel_count in (1, 3, 4)

    if channel_count == 3:
        return image
    if channel_count == 1:
        return np.concatenate([image] * 3, axis=2)

    # Four channels: composite the colour planes over white using alpha.
    rgb = image[:, :, 0:3].astype(np.float32)
    opacity = image[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * opacity + 255.0 * (1.0 - opacity)
    return blended.clip(0, 255).astype(np.uint8)
def resize_image(input_image, resolution):
    """Resize an image so its shorter side is about ``resolution`` pixels.

    The aspect ratio is kept before rounding; each target dimension is
    then rounded to the nearest multiple of 64.

    Args:
        input_image: 3-D (H, W, C) array accepted by ``cv2.resize``.
        resolution: Target length of the shorter side before rounding.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    src_h, src_w, _ = input_image.shape
    scale = float(resolution) / min(float(src_h), float(src_w))
    dst_h = int(np.round(float(src_h) * scale / 64.0)) * 64
    dst_w = int(np.round(float(src_w) * scale / 64.0)) * 64
    # Lanczos when enlarging, area averaging otherwise.
    if scale > 1:
        method = cv2.INTER_LANCZOS4
    else:
        method = cv2.INTER_AREA
    return cv2.resize(input_image, (dst_w, dst_h), interpolation=method)
def controlnet_hint_shaping(
    controlnet_hint, height, width, num_images_per_prompt=1
):
    """Convert a ControlNet hint into a float32 CPU tensor in NCHW layout.

    Args:
        controlnet_hint: The hint, as one of:
            * ``torch.Tensor`` shaped (3, H, W), (1, 3, H, W) or
              (N, 3, H, W);
            * ``np.ndarray`` shaped (H, W), (H, W, 3), (1, H, W, 3) or
              (N, H, W, 3); values are divided by 255;
            * ``PIL.Image.Image`` whose ``size`` equals (width, height).
          Here H is ``height``, W is ``width`` and N is
          ``num_images_per_prompt``.
        height: Expected hint height.
        width: Expected hint width.
        num_images_per_prompt: Batch size the hint is repeated to when it
            is not already batched to that size.

    Returns:
        A float32 CPU tensor of shape (N, 3, H, W).

    Raises:
        ValueError: If the hint has an unsupported type, shape or size.
    """
    channels = 3
    if isinstance(controlnet_hint, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b == 1) or
        # bchw (b == num_images_per_prompt).
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint.shape not in [shape_chw, shape_bchw, shape_nchw]:
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
            )
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        if controlnet_hint.shape != shape_nchw:
            # `repeat` with four factors also promotes a chw tensor to 4-D.
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        return controlnet_hint

    if isinstance(controlnet_hint, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b == 1) or
        # bhwc (b == num_images_per_prompt).
        # hwc is the OpenCV-compatible layout; colour channels must be BGR.
        if controlnet_hint.shape == (height, width):
            controlnet_hint = np.repeat(
                controlnet_hint[:, :, np.newaxis], channels, axis=2
            )  # hw -> hwc (c == 3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint.shape not in [shape_hwc, shape_bhwc, shape_nhwc]:
            # The message lists the shapes actually accepted above
            # (channels-last), not the tensor (channels-first) ones.
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({height}, {width}), "
                + f"({height}, {width}, {channels}), "
                + f"(1, {height}, {width}, {channels}) or "
                + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
            )
        # Copy first: the array may be a negatively-strided view, which
        # `torch.from_numpy` cannot wrap.
        controlnet_hint = torch.from_numpy(controlnet_hint.copy())
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        controlnet_hint /= 255.0
        if controlnet_hint.shape != shape_nhwc:
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        return controlnet_hint.permute(0, 3, 1, 2)  # b h w c -> b c h w

    if isinstance(controlnet_hint, Image.Image):
        if controlnet_hint.size != (width, height):
            raise ValueError(
                f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
            )
        # Make sure the image is 3-channel RGB, then flip to BGR and
        # reuse the ndarray branch.
        controlnet_hint = np.array(controlnet_hint.convert("RGB"))
        controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
        return controlnet_hint_shaping(
            controlnet_hint, height, width, num_images_per_prompt
        )

    raise ValueError(
        f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
    )
def controlnet_hint_conversion(
    image, use_stencil, height, width, num_images_per_prompt=1
):
    """Build the ControlNet hint tensor for the selected stencil.

    Args:
        image: Source image handed to the stencil detector.
        use_stencil: Stencil name; only "canny" is handled.
        height: Expected hint height, passed to ``controlnet_hint_shaping``.
        width: Expected hint width; also passed to ``hint_canny``.
        num_images_per_prompt: Batch size passed to
            ``controlnet_hint_shaping``.

    Returns:
        The result of ``controlnet_hint_shaping`` for the detected hint,
        or ``None`` when ``use_stencil`` is not "canny".
    """
    if use_stencil == "canny":
        print("Detecting edge with canny")
        edge_hint = hint_canny(image, width)
        return controlnet_hint_shaping(
            edge_hint, height, width, num_images_per_prompt
        )
    # No stencil, or an unrecognised one: there is no hint to produce.
    return None
# Stencil 1. Canny
def hint_canny(
    image: Image.Image,
    width=512,
    height=512,
    low_threshold=100,
    high_threshold=200,
):
    """Produce a 3-channel Canny edge map for ``image``.

    Args:
        image: Input image; converted with ``np.array``.
        width: Resolution passed to ``resize_image``.
        height: Accepted but not used by this function.
        low_threshold: Lower threshold passed to the Canny detector.
        high_threshold: Upper threshold passed to the Canny detector.

    Returns:
        The detector output after ``HWC3`` conversion.
    """
    with torch.no_grad():
        pixels = np.array(image)
        target_resolution = width
        resized = resize_image(HWC3(pixels), target_resolution)

        # Create the detector once and reuse it on later calls.
        if "canny" not in stencil:
            stencil["canny"] = CannyDetector()
        detect = stencil["canny"]
        edges = detect(resized, low_threshold, high_threshold)
        return HWC3(edges)