mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
[WIP] Add ControlNet to SD pipeline
-- This commit adds ControlNet to SD pipeline. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
11
apps/README.md
Normal file
11
apps/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# [WIP] ControlNet integration:
|
||||
* This integration depends on `diffusers` changes that currently exist only as a [WIP draft PR](https://github.com/huggingface/diffusers/pull/2407), so you need to clone [takuma104/diffusers](https://github.com/takuma104/diffusers/), check out its `controlnet` [branch](https://github.com/takuma104/diffusers/tree/controlnet), and build `diffusers` from source as described [here](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
|
||||
* NOTE: Run `pip uninstall diffusers` before building it from scratch in the previous step.
|
||||
* Currently we've included `ControlNet` (Canny feature) as part of our SharkStableDiffusion pipeline.
|
||||
* To test it, first download [control_sd15_canny.pth](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth) and convert it with [convert_controlnet_to_diffusers.py](https://github.com/huggingface/diffusers/blob/faf1cfbe826c88366524e92fa27b2104effdb8c4/scripts/convert_controlnet_to_diffusers.py).
|
||||
* Then modify [this line](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py#L62) so that it points to the `diffusers` checkpoint directory produced by the conversion script.
|
||||
* Download the [sample image](https://drive.google.com/file/d/13ncuKjTO-reK-nBcLM1Lzuli5WPPB38v/view?usp=sharing) and set its path [here](https://github.com/nod-ai/SHARK/blob/5e5c86f4893bfdbfc2ca310803beb4ef7146213f/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py#L152).
|
||||
* Finally, run:
|
||||
```shell
|
||||
python apps/stable_diffusion/scripts/txt2img.py --precision=fp32 --device=cuda --prompt="bird" --max_length=64 --import_mlir --enable_stack_trace --no-use_tuned --hf_model_id="runwayml/stable-diffusion-v1-5" --scheduler="DDIM"
|
||||
```
|
||||
@@ -227,13 +227,17 @@ class SharkifyStableDiffusionModel:
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
# TODO: Instead of flattening the `control` try to use the list.
|
||||
def forward(
|
||||
self, latent, timestep, text_embedding, guidance_scale
|
||||
self, latent, timestep, text_embedding, guidance_scale,
|
||||
control1, control2, control3, control4, control5, control6, control7,
|
||||
control8, control9, control10, control11, control12, control13
|
||||
):
|
||||
control = [control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12, control13]
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
latents, timestep, text_embedding, control=control, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
|
||||
@@ -19,6 +19,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
@@ -65,6 +68,66 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# TODO: Move it to a separate ControlNet based specific model functions.
|
||||
def HWC3(self, x):
    """Return `x` as a 3-channel ``uint8`` image in height-width-channel order.

    Args:
        x: ``np.uint8`` array of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4).

    Returns:
        An (H, W, 3) ``np.uint8`` array. A 3-channel input is returned
        unchanged (same object); a 1-channel input is replicated across
        three channels; a 4-channel input has its first three channels
        blended over a white (255) background using the fourth channel as
        alpha.

    Raises:
        TypeError: if `x` is not of dtype ``np.uint8``.
        ValueError: if `x` is not 2-D/3-D or has an unsupported channel count.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let bad input fall through and return None.
    if x.dtype != np.uint8:
        raise TypeError(f"HWC3 expects a uint8 array but got dtype {x.dtype}")
    if x.ndim == 2:
        # Promote (H, W) to (H, W, 1) so the channel logic below applies.
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(
            f"HWC3 expects a 2-D or 3-D array but got {x.ndim} dimensions"
        )

    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        # Alpha-composite the colour channels over a white background.
        blended = color * alpha + 255.0 * (1.0 - alpha)
        return blended.clip(0, 255).astype(np.uint8)
    raise ValueError(
        f"HWC3 expects 1, 3 or 4 channels but got {channels}"
    )
|
||||
|
||||
# TODO: Move it to a separate ControlNet based specific model functions.
|
||||
def resize_image(self, input_image, resolution):
    """Resize an image so its shorter side is about `resolution` pixels.

    The scale factor is ``resolution / min(height, width)``; both scaled
    sides are then rounded to the nearest multiple of 64.

    Args:
        input_image: image array whose first two axes are height and width.
        resolution: target length of the shorter side before rounding.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    src_h, src_w = input_image.shape[:2]
    scale = float(resolution) / min(src_h, src_w)

    # Rounding to a multiple of 64 can yield 0 for small targets, which
    # cv2.resize rejects; clamp each side to one 64-pixel step.
    dst_h = max(64, int(np.round(src_h * scale / 64.0)) * 64)
    dst_w = max(64, int(np.round(src_w * scale / 64.0)) * 64)

    # Lanczos when enlarging, area averaging when shrinking (or unchanged).
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes the destination size as (width, height).
    return cv2.resize(
        input_image,
        (dst_w, dst_h),
        interpolation=interpolation,
    )
|
||||
|
||||
# TODO: Move it to a separate ControlNet based specific model functions.
|
||||
def hint_canny(
    self,
    image: Image.Image,
    width=512,
    height=512,
    low_threshold=100,
    high_threshold=200,
):
    """Build a Canny edge map from `image` for use as a ControlNet hint.

    The image is converted to a NumPy array, normalised with `HWC3`,
    resized with `resize_image` using `width` as the target resolution,
    run through ``cv2.Canny`` and passed through `HWC3` once more.

    Args:
        image: source image.
        width: resolution handed to `resize_image`.
        height: accepted but not read by this method.
        low_threshold: lower threshold forwarded to ``cv2.Canny``.
        high_threshold: upper threshold forwarded to ``cv2.Canny``.

    Returns:
        The edge map after the final `HWC3` conversion.
    """
    with torch.no_grad():
        pixels = np.array(image)
        resized = self.resize_image(self.HWC3(pixels), width)
        edges = cv2.Canny(resized, low_threshold, high_threshold)
        return self.HWC3(edges)
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
@@ -80,6 +143,18 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. `controlnet_hint`.
|
||||
# 2. Change `num_images_per_prompt`.
|
||||
# 3. Supply `controlnet_img`.
|
||||
from PIL import Image
|
||||
|
||||
controlnet_img = Image.open("/home/abhishek/cn_1.png")
|
||||
controlnet_hint = self.hint_canny(controlnet_img)
|
||||
controlnet_hint = self.controlnet_hint_conversion(
|
||||
controlnet_hint, height, width, num_images_per_prompt=1
|
||||
)
|
||||
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
@@ -122,6 +197,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
controlnet_hint=controlnet_hint,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
|
||||
@@ -55,6 +55,12 @@ class StableDiffusionPipeline:
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
# TODO: Make this dynamic like other models which'll be passed to StableDiffusionPipeline.
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
self.controlnet = UNet2DConditionModel.from_pretrained(
|
||||
"/home/abhishek/weights/canny_weight", subfolder="controlnet"
|
||||
)
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -116,6 +122,7 @@ class StableDiffusionPipeline:
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
controlnet_hint=None,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
@@ -126,7 +133,7 @@ class StableDiffusionPipeline:
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
@@ -142,16 +149,56 @@ class StableDiffusionPipeline:
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
if controlnet_hint is not None:
|
||||
if not torch.is_tensor(latent_model_input):
|
||||
latent_model_input_1 = torch.from_numpy(
|
||||
np.asarray(latent_model_input)
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
control = self.controlnet(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
encoder_hidden_states=text_embeddings,
|
||||
controlnet_hint=controlnet_hint,
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
timestep = timestep.detach().numpy()
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -177,6 +224,80 @@ class StableDiffusionPipeline:
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def controlnet_hint_conversion(
    self, controlnet_hint, height, width, num_images_per_prompt=1
):
    """Normalise a ControlNet hint to a float32 CPU tensor in NCHW layout.

    Args:
        controlnet_hint: one of
            * ``torch.Tensor`` shaped (3, H, W), (1, 3, H, W) or
              (num_images_per_prompt, 3, H, W);
            * ``np.ndarray`` shaped (H, W), (H, W, 3), (1, H, W, 3) or
              (num_images_per_prompt, H, W, 3); values are divided by 255;
            * ``PIL.Image.Image`` whose ``size`` equals ``(width, height)``.
        height: expected hint height.
        width: expected hint width.
        num_images_per_prompt: batch size of the returned tensor.

    Returns:
        A ``torch.float32`` CPU tensor of shape
        (num_images_per_prompt, 3, height, width).

    Raises:
        ValueError: if the hint has an unsupported type, shape or size.
    """
    channels = 3

    if isinstance(controlnet_hint, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b == 1) or
        # bchw (b == num_images_per_prompt).
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint.shape not in [shape_chw, shape_bchw, shape_nchw]:
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
            )
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        if controlnet_hint.shape != shape_nchw:
            # Expands chw / 1chw to the requested batch size.
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        return controlnet_hint

    if isinstance(controlnet_hint, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b == 1) or
        # bhwc (b == num_images_per_prompt).
        # hwc is the OpenCV-compatible image format; colour channels are
        # expected in BGR order.
        if controlnet_hint.shape == (height, width):
            controlnet_hint = np.repeat(
                controlnet_hint[:, :, np.newaxis], channels, axis=2
            )  # hw -> hwc (c == 3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint.shape not in [shape_hwc, shape_bhwc, shape_nhwc]:
            # The message lists exactly the layouts checked above.
            raise ValueError(
                f"Acceptable shape of `controlnet_hint` are any of ({height}, {width}), "
                + f"({height}, {width}, {channels}), "
                + f"(1, {height}, {width}, {channels}) or "
                + f"({num_images_per_prompt}, {height}, {width}, {channels}) but is {controlnet_hint.shape}"
            )
        # `.copy()` gives from_numpy a contiguous buffer; the PIL path below
        # hands in a negatively-strided view.
        controlnet_hint = torch.from_numpy(controlnet_hint.copy())
        controlnet_hint = controlnet_hint.to(
            dtype=torch.float32, device=torch.device("cpu")
        )
        controlnet_hint /= 255.0
        if controlnet_hint.shape != shape_nhwc:
            controlnet_hint = controlnet_hint.repeat(
                num_images_per_prompt, 1, 1, 1
            )
        # b h w c -> b c h w
        return controlnet_hint.permute(0, 3, 1, 2)

    if isinstance(controlnet_hint, Image.Image):
        if controlnet_hint.size != (width, height):
            raise ValueError(
                f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
            )
        # Force 3-channel RGB, convert to NumPy, then flip RGB -> BGR before
        # re-entering through the ndarray branch.
        controlnet_hint = controlnet_hint.convert("RGB")
        controlnet_hint = np.array(controlnet_hint)
        controlnet_hint = controlnet_hint[:, :, ::-1]
        return self.controlnet_hint_conversion(
            controlnet_hint, height, width, num_images_per_prompt
        )

    raise ValueError(
        f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
    )
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
|
||||
@@ -83,6 +83,58 @@
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control1": {
|
||||
"shape": [1, 320, 64, 64],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control2": {
|
||||
"shape": [1, 320, 64, 64],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control3": {
|
||||
"shape": [1, 320, 64, 64],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control4": {
|
||||
"shape": [1, 320, 32, 32],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control5": {
|
||||
"shape": [1, 640, 32, 32],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control6": {
|
||||
"shape": [1, 640, 32, 32],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control7": {
|
||||
"shape": [1, 640, 16, 16],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control8": {
|
||||
"shape": [1, 1280, 16, 16],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control9": {
|
||||
"shape": [1, 1280, 16, 16],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control10": {
|
||||
"shape": [1, 1280, 8, 8],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control11": {
|
||||
"shape": [1, 1280, 8, 8],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control12": {
|
||||
"shape": [1, 1280, 8, 8],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control13": {
|
||||
"shape": [1, 1280, 8, 8],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
|
||||
Reference in New Issue
Block a user