[SD] Move initial latent generation out of inference time

The initial random latent generation is moved out of the timed region,
so it is no longer counted toward the total SD inference time.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-19 21:41:58 +05:30
parent 3173b7d1d9
commit b2b3a0a62b
3 changed files with 41 additions and 33 deletions

View File

@@ -70,6 +70,13 @@ if __name__ == "__main__":
if batch_size != len(neg_prompt):
sys.exit("prompts and negative prompts must be of same length")
# create a random initial latent.
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
set_iree_runtime_flags()
unet = get_unet()
vae = get_vae()
@@ -136,21 +143,15 @@ if __name__ == "__main__":
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
avg_ms = 0
avg_ms = 0
for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
step_start = time.time()
if args.hide_steps == False:
if not args.hide_steps:
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
@@ -179,12 +180,9 @@ if __name__ == "__main__":
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
if args.hide_steps == False:
if not args.hide_steps:
print(f" ({step_ms}ms)")
avg_ms = 1000 * avg_ms / args.steps
print(f"Average step time: {avg_ms}ms/it")
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
@@ -195,8 +193,10 @@ if __name__ == "__main__":
end_profiling(profile_device)
total_end = time.time()
avg_ms = 1000 * avg_ms / args.steps
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
print(f"Average step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"Total image generation runtime (s): {total_end - start:.4f}")

View File

@@ -48,6 +48,13 @@ def stable_diff_inf(
height = 768
width = 768
# create a random initial latent.
latents = torch.randn(
(1, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
# Initialize vae and unet models.
vae, unet, clip, tokenizer = (
cache_obj["vae"],
@@ -82,19 +89,12 @@ def stable_diff_inf(
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
latents = torch.randn(
(1, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(args.steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
avg_ms = 0
out_img = None
for i, t in tqdm(enumerate(scheduler.timesteps)):
step_start = time.time()
@@ -121,7 +121,8 @@ def stable_diff_inf(
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
print(f" \nIteration = {i}, Time = {step_ms}ms")
if not args.hide_steps:
print(f" \nIteration = {i}, Time = {step_ms}ms")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
@@ -134,24 +135,23 @@ def stable_diff_inf(
image = torch.from_numpy(image)
image = (image.detach().cpu().permute(0, 2, 3, 1) * 255.0).numpy()
images = image.round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
end_time = time.time()
avg_ms = 1000 * avg_ms / args.steps
total_time = time.time() - start
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")
text_output += "\nTotal image generation time: {0:.2f}sec".format(
total_time
)
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
total_time = end_time - start
print(f"\nAverage step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"\nTotal image generation time: {total_time}sec")
return out_img, text_output
pil_images = [Image.fromarray(image) for image in images]
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
text_output += f"\nTotal image generation time: {total_time:.2f}sec"
return pil_images[0], text_output

View File

@@ -163,4 +163,12 @@ p.add_argument(
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
# Register --hide_steps as a boolean command-line option that defaults to
# False. argparse.BooleanOptionalAction also accepts the negated form
# (--no-hide_steps) on the command line.
p.add_argument(
"--hide_steps",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
args = p.parse_args()