mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
This commit updates VAE model which significantly improves performance by an order of ~300ms. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
import torch
|
|
from PIL import Image
|
|
import torchvision.transforms as T
|
|
from tqdm.auto import tqdm
|
|
from models.stable_diffusion.cache_objects import (
|
|
cache_obj,
|
|
schedulers,
|
|
)
|
|
from models.stable_diffusion.stable_args import args
|
|
from random import randint
|
|
import numpy as np
|
|
import time
|
|
|
|
|
|
def set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed):
|
|
args.prompts = [prompt]
|
|
args.negative_prompts = [negative_prompt]
|
|
args.steps = steps
|
|
args.guidance_scale = guidance_scale
|
|
args.seed = seed
|
|
|
|
|
|
def stable_diff_inf(
|
|
prompt: str,
|
|
negative_prompt: str,
|
|
steps: int,
|
|
guidance_scale: float,
|
|
seed: int,
|
|
scheduler_key: str,
|
|
):
|
|
|
|
# Handle out of range seeds.
|
|
uint32_info = np.iinfo(np.uint32)
|
|
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
|
if seed < uint32_min or seed >= uint32_max:
|
|
seed = randint(uint32_min, uint32_max)
|
|
|
|
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
|
set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed)
|
|
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
|
generator = torch.manual_seed(
|
|
args.seed
|
|
) # Seed generator to create the inital latent noise
|
|
|
|
# set height and width.
|
|
height = 512 # default height of Stable Diffusion
|
|
width = 512 # default width of Stable Diffusion
|
|
if args.version == "v2.1":
|
|
height = 768
|
|
width = 768
|
|
|
|
# create a random initial latent.
|
|
latents = torch.randn(
|
|
(1, 4, height // 8, width // 8),
|
|
generator=generator,
|
|
dtype=torch.float32,
|
|
).to(dtype)
|
|
|
|
# Initialize vae and unet models.
|
|
vae, unet, clip, tokenizer = (
|
|
cache_obj["vae"],
|
|
cache_obj["unet"],
|
|
cache_obj["clip"],
|
|
cache_obj["tokenizer"],
|
|
)
|
|
scheduler = schedulers[scheduler_key]
|
|
cpu_scheduling = not scheduler_key.startswith("Shark")
|
|
|
|
start = time.time()
|
|
text_input = tokenizer(
|
|
args.prompts,
|
|
padding="max_length",
|
|
max_length=args.max_length,
|
|
truncation=True,
|
|
return_tensors="pt",
|
|
)
|
|
max_length = text_input.input_ids.shape[-1]
|
|
uncond_input = tokenizer(
|
|
args.negative_prompts,
|
|
padding="max_length",
|
|
max_length=max_length,
|
|
truncation=True,
|
|
return_tensors="pt",
|
|
)
|
|
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
|
|
|
|
clip_inf_start = time.time()
|
|
text_embeddings = clip.forward((text_input,))
|
|
clip_inf_end = time.time()
|
|
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
|
text_embeddings_numpy = text_embeddings.detach().numpy()
|
|
|
|
scheduler.set_timesteps(args.steps)
|
|
scheduler.is_scale_input_called = True
|
|
|
|
latents = latents * scheduler.init_noise_sigma
|
|
|
|
avg_ms = 0
|
|
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
|
|
|
step_start = time.time()
|
|
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
|
latent_model_input = scheduler.scale_model_input(latents, t)
|
|
if cpu_scheduling:
|
|
latent_model_input = latent_model_input.detach().numpy()
|
|
|
|
noise_pred = unet.forward(
|
|
(
|
|
latent_model_input,
|
|
timestep,
|
|
text_embeddings_numpy,
|
|
args.guidance_scale,
|
|
),
|
|
send_to_host=False,
|
|
)
|
|
|
|
if cpu_scheduling:
|
|
noise_pred = torch.from_numpy(noise_pred.to_host())
|
|
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
|
else:
|
|
latents = scheduler.step(noise_pred, t, latents)
|
|
step_time = time.time() - step_start
|
|
avg_ms += step_time
|
|
step_ms = int((step_time) * 1000)
|
|
if not args.hide_steps:
|
|
print(f" \nIteration = {i}, Time = {step_ms}ms")
|
|
|
|
# scale and decode the image latents with vae
|
|
latents_numpy = latents
|
|
if cpu_scheduling:
|
|
latents_numpy = latents.detach().numpy()
|
|
vae_start = time.time()
|
|
images = vae.forward((latents_numpy,))
|
|
vae_end = time.time()
|
|
end_time = time.time()
|
|
|
|
avg_ms = 1000 * avg_ms / args.steps
|
|
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
|
|
vae_inf_time = (vae_end - vae_start) * 1000
|
|
total_time = end_time - start
|
|
print(f"\nAverage step time: {avg_ms}ms/it")
|
|
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
|
|
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
|
|
print(f"\nTotal image generation time: {total_time}sec")
|
|
|
|
# generate outputs to web.
|
|
transform = T.ToPILImage()
|
|
pil_images = [
|
|
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
|
|
]
|
|
|
|
text_output = f"prompt={args.prompts}"
|
|
text_output += f"\nnegative prompt={args.negative_prompts}"
|
|
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
|
|
text_output += f"\nAverage step time: {avg_ms:.2f}ms/it"
|
|
text_output += f"\nTotal image generation time: {total_time:.2f}sec"
|
|
|
|
return pil_images[0], text_output
|