Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
[SD] Move initial latent generation out of inference time
Move the initial random latent generation ahead of the timed section so it is no longer counted toward total SD inference time.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
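In outline, the change moves one-time setup ahead of the timed region. A minimal sketch of the pattern, with hypothetical stand-ins (denoise_step, the step count) rather than the project's actual functions:

import time

import torch

def denoise_step(latents: torch.Tensor) -> torch.Tensor:
    # Placeholder for one UNet denoising step.
    return latents * 0.99

batch_size, height, width = 1, 512, 512
generator = torch.Generator().manual_seed(42)

# Create the random initial latent BEFORE the clock starts, so it is
# not counted toward inference time -- the point of this commit.
latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
    dtype=torch.float32,
)

start = time.time()  # the timer now covers only the denoising loop
for _ in range(20):
    latents = denoise_step(latents)
print(f"inference time: {time.time() - start:.3f}s")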
@@ -70,6 +70,13 @@ if __name__ == "__main__":
     if batch_size != len(neg_prompt):
         sys.exit("prompts and negative prompts must be of same length")
 
+    # create a random initial latent.
+    latents = torch.randn(
+        (batch_size, 4, height // 8, width // 8),
+        generator=generator,
+        dtype=torch.float32,
+    ).to(dtype)
+
     set_iree_runtime_flags()
     unet = get_unet()
     vae = get_vae()
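The (batch_size, 4, height // 8, width // 8) shape added above follows Stable Diffusion's latent convention: the VAE downsamples by a factor of 8 and the UNet works on 4 latent channels. A small standalone check (values assumed for illustration) that a seeded torch.Generator makes the initial latent, and hence the run, reproducible:

import torch

shape = (1, 4, 512 // 8, 512 // 8)  # (batch, latent channels, H/8, W/8)

# Two generators with the same seed produce the same initial latent.
a = torch.randn(shape, generator=torch.Generator().manual_seed(1234))
b = torch.randn(shape, generator=torch.Generator().manual_seed(1234))
assert torch.equal(a, b)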
@@ -136,21 +143,15 @@ if __name__ == "__main__":
     text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
     text_embeddings_numpy = text_embeddings.detach().numpy()
 
-    latents = torch.randn(
-        (batch_size, 4, height // 8, width // 8),
-        generator=generator,
-        dtype=torch.float32,
-    ).to(dtype)
-
     scheduler.set_timesteps(num_inference_steps)
     scheduler.is_scale_input_called = True
 
     latents = latents * scheduler.init_noise_sigma
-    avg_ms = 0
 
+    avg_ms = 0
     for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
         step_start = time.time()
-        if args.hide_steps == False:
+        if not args.hide_steps:
             print(f"i = {i} t = {t}", end="")
         timestep = torch.tensor([t]).to(dtype).detach().numpy()
         latent_model_input = scheduler.scale_model_input(latents, t)
@@ -179,12 +180,9 @@ if __name__ == "__main__":
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
-        if args.hide_steps == False:
+        if not args.hide_steps:
             print(f" ({step_ms}ms)")
 
-    avg_ms = 1000 * avg_ms / args.steps
-    print(f"Average step time: {avg_ms}ms/it")
-
     latents_numpy = latents
     if cpu_scheduling:
         latents_numpy = latents.detach().numpy()
@@ -195,8 +193,10 @@ if __name__ == "__main__":
     end_profiling(profile_device)
     total_end = time.time()
 
+    avg_ms = 1000 * avg_ms / args.steps
     clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
     vae_inf_time = (vae_end - vae_start) * 1000
+    print(f"Average step time: {avg_ms}ms/it")
     print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
     print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
     print(f"Total image generation runtime (s): {total_end - start:.4f}")
@@ -48,6 +48,13 @@ def stable_diff_inf(
     height = 768
     width = 768
 
+    # create a random initial latent.
+    latents = torch.randn(
+        (1, 4, height // 8, width // 8),
+        generator=generator,
+        dtype=torch.float32,
+    ).to(dtype)
+
     # Initialize vae and unet models.
     vae, unet, clip, tokenizer = (
         cache_obj["vae"],
@@ -82,19 +89,12 @@ def stable_diff_inf(
     text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
     text_embeddings_numpy = text_embeddings.detach().numpy()
 
-    latents = torch.randn(
-        (1, 4, height // 8, width // 8),
-        generator=generator,
-        dtype=torch.float32,
-    ).to(dtype)
-
     scheduler.set_timesteps(args.steps)
     scheduler.is_scale_input_called = True
 
     latents = latents * scheduler.init_noise_sigma
 
     avg_ms = 0
-    out_img = None
     for i, t in tqdm(enumerate(scheduler.timesteps)):
 
         step_start = time.time()
@@ -121,7 +121,8 @@ def stable_diff_inf(
         step_time = time.time() - step_start
         avg_ms += step_time
         step_ms = int((step_time) * 1000)
-        print(f" \nIteration = {i}, Time = {step_ms}ms")
+        if not args.hide_steps:
+            print(f" \nIteration = {i}, Time = {step_ms}ms")
 
     # scale and decode the image latents with vae
     latents = 1 / 0.18215 * latents
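The 1 / 0.18215 above undoes the scale factor applied when images are encoded into Stable Diffusion's latent space; the next hunk then converts the decoded tensor to a PIL image. A self-contained sketch of that post-decode conversion, using a random array in place of the real VAE output (assumed here to be NCHW float32 in [0, 1]):

import numpy as np
import torch
from PIL import Image

decoded = np.random.rand(1, 3, 64, 64).astype(np.float32)  # stand-in VAE output

image = torch.from_numpy(decoded)
# NCHW -> NHWC, scale to 0-255, quantize to bytes for PIL.
image = (image.detach().cpu().permute(0, 2, 3, 1) * 255.0).numpy()
images = image.round().astype("uint8")
pil_images = [Image.fromarray(arr) for arr in images]
pil_images[0].save("out.png")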
@@ -134,24 +135,23 @@ def stable_diff_inf(
     image = torch.from_numpy(image)
     image = (image.detach().cpu().permute(0, 2, 3, 1) * 255.0).numpy()
     images = image.round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-    out_img = pil_images[0]
-    end_time = time.time()
 
     avg_ms = 1000 * avg_ms / args.steps
+    total_time = time.time() - start
 
-    text_output = f"prompt={args.prompts}"
-    text_output += f"\nnegative prompt={args.negative_prompts}"
-    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
-    text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
-    print(f"\nAverage step time: {avg_ms}ms/it")
-    text_output += "\nTotal image generation time: {0:.2f}sec".format(
-        total_time
-    )
     clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
     vae_inf_time = (vae_end - vae_start) * 1000
-    total_time = end_time - start
     print(f"\nAverage step time: {avg_ms}ms/it")
     print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
     print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
     print(f"\nTotal image generation time: {total_time}sec")
 
-    return out_img, text_output
+    pil_images = [Image.fromarray(image) for image in images]
+
+    text_output = f"prompt={args.prompts}"
+    text_output += f"\nnegative prompt={args.negative_prompts}"
+    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
+    text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
+    text_output += f"\nTotal image generation time: {total_time:.2f}sec"
+
+    return pil_images[0], text_output
@@ -163,4 +163,12 @@ p.add_argument(
     action=argparse.BooleanOptionalAction,
     help="flag for inserting debug frames between iterations for use with rgp.",
 )
+
+p.add_argument(
+    "--hide_steps",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="flag for hiding the details of iteration/sec for each step.",
+)
+
 args = p.parse_args()
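A note on the new flag: argparse.BooleanOptionalAction (Python 3.9+) registers a matching --no-hide_steps negation automatically, so the flag can be toggled either way. A quick standalone demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--hide_steps",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for hiding the details of iteration/sec for each step.",
)

print(p.parse_args(["--hide_steps"]).hide_steps)     # True
print(p.parse_args(["--no-hide_steps"]).hide_steps)  # False
print(p.parse_args([]).hide_steps)                   # False (the default)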