[SD] Improve Stencil feature to handle general image sizes

-- Currently stencil feature works with 512x512 images only.
-- This commit relaxes this constraint and adds support for various
   image sizes.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-03-11 15:04:01 +00:00
committed by Abhishek Varma
parent 16ad7d57a3
commit 691030fbab
5 changed files with 55 additions and 51 deletions

View File

@@ -38,6 +38,23 @@ init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model constraints:
# Both width and height should be >= 384 and a multiple of 8.
# This utility function performs the transformation on the input image before
# sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
    """Resize ``image`` so that both sides are multiples of 8.

    Each side is rounded down to the nearest multiple of 8 and the image is
    resized to those dimensions.

    Args:
        image: Input image; its ``size`` is read as ``(width, height)``.

    Returns:
        A ``(resized_image, new_width, new_height)`` tuple.

    Exits:
        Calls ``sys.exit`` with an error message when either side is
        smaller than 384.
    """
    width, height = image.size
    # NOTE(review): sys.exit terminates the whole process; this function is
    # also reached from the UI entry point -- confirm that is intended.
    if min(width, height) < 384:
        sys.exit("width and height should at least be 384")
    # Drop the remainder so each side lands on a multiple of 8.
    n_width = width - width % 8
    n_height = height - height % 8
    return image.resize((n_width, n_height)), n_width, n_height
# Exposed to UI.
def img2img_inf(
prompt: str,
@@ -105,6 +122,7 @@ def img2img_inf(
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
@@ -236,6 +254,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# Once the models are uploaded, this should default to False.
args.import_mlir = True
@@ -243,6 +262,7 @@ if __name__ == "__main__":
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
@@ -257,9 +277,7 @@ if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path).convert("RGB")
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model

View File

@@ -31,13 +31,23 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
if "*" in shape[i]:
mul_val = int(shape[i].split("*")[0])
if "batch_size" in shape[i]:
new_shape.append(batch_size * mul_val)
elif "height" in shape[i]:
new_shape.append(height * mul_val)
elif "width" in shape[i]:
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
elif "height" in shape[i]:
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
else:
new_shape.append(shape[i])
return new_shape

View File

@@ -20,9 +20,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
StableDiffusionPipeline,
)
import cv2
from PIL import Image
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(

View File

@@ -110,7 +110,7 @@
"dtype": "f32"
},
"controlnet_hint": {
"shape": [1, 3, 512, 512],
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
}
},
@@ -143,55 +143,55 @@
"dtype": "f32"
},
"control1": {
"shape": [2, 320, 64, 64],
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control2": {
"shape": [2, 320, 64, 64],
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control3": {
"shape": [2, 320, 64, 64],
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"control4": {
"shape": [2, 320, 32, 32],
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"control5": {
"shape": [2, 640, 32, 32],
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control6": {
"shape": [2, 640, 32, 32],
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"control7": {
"shape": [2, 640, 16, 16],
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"control8": {
"shape": [2, 1280, 16, 16],
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control9": {
"shape": [2, 1280, 16, 16],
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"control10": {
"shape": [2, 1280, 8, 8],
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control11": {
"shape": [2, 1280, 8, 8],
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control12": {
"shape": [2, 1280, 8, 8],
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"control13": {
"shape": [2, 1280, 8, 8],
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},

View File

@@ -1,4 +1,3 @@
import cv2
import numpy as np
from PIL import Image
import torch
@@ -26,23 +25,6 @@ def HWC3(x):
return y
def resize_image(input_image, resolution):
    """Rescale an image so its shorter side is approximately ``resolution``.

    The scaled height and width are each rounded to the nearest multiple of
    64 before the image is resized with OpenCV.

    Args:
        input_image: Array whose ``shape`` unpacks into three values, read
            here as (height, width, channels).
        resolution: Target length for the shorter side before rounding.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    src_h, src_w, _ = input_image.shape
    src_h, src_w = float(src_h), float(src_w)
    scale = float(resolution) / min(src_h, src_w)

    # Snap the scaled dimensions to the nearest multiple of 64.
    dst_h = int(np.round(src_h * scale / 64.0)) * 64
    dst_w = int(np.round(src_w * scale / 64.0)) * 64

    # Lanczos when enlarging, area averaging otherwise.
    if scale > 1:
        interpolation = cv2.INTER_LANCZOS4
    else:
        interpolation = cv2.INTER_AREA
    return cv2.resize(input_image, (dst_w, dst_h), interpolation=interpolation)
def controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
@@ -125,7 +107,7 @@ def controlnet_hint_conversion(
match use_stencil:
case "canny":
print("Detecting edge with canny")
controlnet_hint = hint_canny(image, width)
controlnet_hint = hint_canny(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
@@ -155,19 +137,16 @@ def get_stencil_model_id(use_stencil):
# Stencil 1. Canny
def hint_canny(
image: Image.Image,
width=512,
height=512,
low_threshold=100,
high_threshold=200,
):
with torch.no_grad():
input_image = np.array(image)
image_resolution = width
img = resize_image(HWC3(input_image), image_resolution)
if not "canny" in stencil:
stencil["canny"] = CannyDetector()
detected_map = stencil["canny"](img, low_threshold, high_threshold)
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
detected_map = HWC3(detected_map)
return detected_map