[SD] Improve Stencil feature to handle general image sizes

-- Currently stencil feature works with 512x512 images only.
-- This commit relaxes this constraint and adds support for various
   image sizes.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-03-11 15:04:01 +00:00
committed by Abhishek Varma
parent 16ad7d57a3
commit 691030fbab
5 changed files with 55 additions and 51 deletions

View File

@@ -38,6 +38,23 @@ init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# Both width and height should be > 384 and multiple of 8.
# This utility function performs the transformation on the input image before
# sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
if width < 384 or height < 384:
sys.exit("width and height should at least be 384")
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
new_image = image.resize((n_width, n_height))
return new_image, n_width, n_height
# Exposed to UI.
def img2img_inf(
prompt: str,
@@ -105,6 +122,7 @@ def img2img_inf(
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
@@ -236,6 +254,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
image = Image.open(args.img_path).convert("RGB")
# When the models get uploaded, it should be default to False.
args.import_mlir = True
@@ -243,6 +262,7 @@ if __name__ == "__main__":
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, args.width, args.height = resize_stencil(image)
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
@@ -257,9 +277,7 @@ if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path).convert("RGB")
seed = utils.sanitize_seed(args.seed)
# Adjust for height and width based on model