add strength param

This commit is contained in:
PhaneeshB
2023-02-10 19:23:39 +05:30
committed by Phaneesh Barwaria
parent 1ce02e365d
commit 0430c741c6
3 changed files with 37 additions and 3 deletions

View File

@@ -77,7 +77,7 @@ def save_output_img(output_img):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Strength: {args.strength}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
@@ -98,6 +98,7 @@ def save_output_img(output_img):
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"STRENGTH": args.strength,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
@@ -129,6 +130,7 @@ def img2img_inf(
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: int,
batch_count: int,
@@ -151,6 +153,7 @@ def img2img_inf(
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.strength = strength
args.scheduler = scheduler
args.img_path = init_image
image = Image.open(args.img_path)
@@ -240,6 +243,7 @@ def img2img_inf(
height,
width,
steps,
strength,
guidance_scale,
seed,
args.max_length,
@@ -275,6 +279,15 @@ if __name__ == "__main__":
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
if args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
print(
f"[WARNING] Img2Img works best with PNDM scheduler. Please use that"
)
scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path)
@@ -304,6 +317,7 @@ if __name__ == "__main__":
args.height,
args.width,
args.steps,
args.strength,
args.guidance_scale,
args.seed,
args.max_length,

View File

@@ -50,6 +50,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
width,
generator,
num_inference_steps,
strength,
dtype,
):
# Pre process image -> get image encoded -> process latents
@@ -57,7 +58,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((height, width)) # Current support for 512x512
image = image.resize((width, height))
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -69,9 +70,20 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# with strength applied, the effective number of inference steps becomes
# num_inference_steps - t_start
# add noise to data
latents = latents * self.scheduler.init_noise_sigma
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents
@@ -92,6 +104,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
height,
width,
num_inference_steps,
strength,
guidance_scale,
seed,
max_length,
@@ -130,6 +143,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
)

View File

@@ -91,6 +91,12 @@ p.add_argument(
help="max length of the tokenizer output, options are 64 and 77.",
)
# Register --strength: a float option (default 0.8) that, per its help text,
# sets how strongly the given input image is changed during img2img.
# NOTE(review): no range check is applied here. The img2img pipeline appears to
# turn this into a step count via int(num_inference_steps * strength), so
# values <= 0 (or small values that truncate to 0) presumably leave no
# timesteps to run -- confirm, and consider validating the range at parse time.
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
)
##############################################################################
### Model Config and Usage Params
##############################################################################