mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[SD] Improve Stencil feature to handle general image sizes
-- Currently stencil feature works with 512x512 images only. -- This commit relaxes this constraint and adds support for various image sizes. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
committed by
Abhishek Varma
parent
16ad7d57a3
commit
691030fbab
@@ -38,6 +38,23 @@ init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size but we need to ensure that
|
||||
# it conforms with our model contraints :-
|
||||
# Both width and height should be > 384 and multiple of 8.
|
||||
# This utility function performs the transformation on the input image before
|
||||
# sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
if width < 384 or height < 384:
|
||||
sys.exit("width and height should at least be 384")
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
new_image = image.resize((n_width, n_height))
|
||||
return new_image, n_width, n_height
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
@@ -105,6 +122,7 @@ def img2img_inf(
|
||||
if use_stencil is not None:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, width, height = resize_stencil(image)
|
||||
elif args.scheduler != "PNDM":
|
||||
if "Shark" in args.scheduler:
|
||||
print(
|
||||
@@ -236,6 +254,7 @@ if __name__ == "__main__":
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
# When the models get uploaded, it should be default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
@@ -243,6 +262,7 @@ if __name__ == "__main__":
|
||||
if use_stencil:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, args.width, args.height = resize_stencil(image)
|
||||
elif args.scheduler != "PNDM":
|
||||
if "Shark" in args.scheduler:
|
||||
print(
|
||||
@@ -257,9 +277,7 @@ if __name__ == "__main__":
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
# Adjust for height and width based on model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user