diff --git a/shark/examples/shark_inference/stable_diffusion/main.py b/shark/examples/shark_inference/stable_diffusion/main.py index b0fa3ac0..d58e3900 100644 --- a/shark/examples/shark_inference/stable_diffusion/main.py +++ b/shark/examples/shark_inference/stable_diffusion/main.py @@ -1,7 +1,7 @@ from transformers import CLIPTextModel, CLIPTokenizer import torch from PIL import Image -from diffusers import LMSDiscreteScheduler +from diffusers import LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler from tqdm.auto import tqdm import numpy as np from stable_args import args @@ -47,7 +47,7 @@ if __name__ == "__main__": batch_size = len(prompt) set_iree_runtime_flags() - unet = get_unet() + unet_lms, unet = get_unet() vae = get_vae() clip = get_clip() @@ -93,7 +93,7 @@ if __name__ == "__main__": scheduler.set_timesteps(num_inference_steps) scheduler.is_scale_input_called = True - latents = latents * scheduler.sigmas[0] + latents = latents * scheduler.init_noise_sigma text_embeddings_numpy = text_embeddings.detach().numpy() avg_ms = 0 @@ -101,29 +101,42 @@ if __name__ == "__main__": step_start = time.time() print(f"i = {i} t = {t}", end="") timestep = torch.tensor([t]).to(dtype).detach().numpy() - if args.precision == "int8": - timestep = np.array(t).astype("int64") latents_numpy = latents.detach().numpy() - sigma_numpy = np.array(scheduler.sigmas[i]).astype(np.float32) profile_device = start_profiling(file_path="unet.rdc") - noise_pred = unet.forward( - ( - latents_numpy, - timestep, - text_embeddings_numpy, - sigma_numpy, - guidance_scale, + + noise_pred = None + if isinstance(scheduler, LMSDiscreteScheduler): + sigma_numpy = np.array(scheduler.sigmas[i]).astype(np.float32) + noise_pred = unet_lms.forward( + ( + latents_numpy, + timestep, + text_embeddings_numpy, + sigma_numpy, + guidance_scale, + ) ) - ) + else: + noise_pred = unet.forward( + ( + latents_numpy, + timestep, + text_embeddings_numpy, + guidance_scale, + ) + ) + end_profiling(profile_device) + noise_pred = torch.from_numpy(noise_pred) step_time = time.time() - step_start avg_ms += step_time step_ms = int((step_time) * 1000) print(f" ({step_ms}ms)") - latents = scheduler.step(noise_pred, i, latents)["prev_sample"] + latents = scheduler.step(noise_pred, t, latents).prev_sample + avg_ms = 1000 * avg_ms / args.steps print(f"Average step time: {avg_ms}ms/it") diff --git a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py index 3d920ae1..4daeb037 100644 --- a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py +++ b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py @@ -84,7 +84,7 @@ def get_vae16(model_name="vae_fp16", extra_args=[]): return shark_vae -def get_unet16_wrapped(model_name="unet_fp16_wrapped", extra_args=[]): +def get_unet16_lms(model_name="unet_fp16_wrapped", extra_args=[]): class UnetModel(torch.nn.Module): def __init__(self): super().__init__() @@ -135,7 +135,7 @@ def get_unet16_wrapped(model_name="unet_fp16_wrapped", extra_args=[]): return shark_unet -def get_unet32_wrapped(model_name="unet_fp32_wrapped", extra_args=[]): +def get_unet32_lms(model_name="unet_fp32_wrapped", extra_args=[]): class UnetModel(torch.nn.Module): def __init__(self): super().__init__() @@ -179,3 +179,90 @@ def get_unet32_wrapped(model_name="unet_fp32_wrapped", extra_args=[]): extra_args=extra_args, ) return shark_unet + + +def get_unet32(model_name="unet_fp32_wrapped", extra_args=[]): + class UnetModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + "CompVis/stable-diffusion-v1-4", + subfolder="unet", + use_auth_token=YOUR_TOKEN, + ) + self.in_channels = self.unet.in_channels + self.train(False) + + def forward(self, latent, timestep, text_embedding, guidance_scale): + latents = torch.cat([latent] * 2) + unet_out = self.unet.forward( + latents, timestep, text_embedding, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + unet = UnetModel() + latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64]) + text_embeddings = torch.rand([2 * BATCH_SIZE, args.max_length, 768]) + guidance_scale = torch.tensor(1).to(torch.float32) + shark_unet = compile_through_fx( + unet, + ( + latent_model_input, + torch.tensor([1.0]), + text_embeddings, + guidance_scale, + ), + model_name=model_name, + extra_args=extra_args, + ) + return shark_unet + + +def get_unet16(model_name="unet_fp16", extra_args=[]): + class UnetModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + "CompVis/stable-diffusion-v1-4", + subfolder="unet", + use_auth_token=YOUR_TOKEN, + revision="fp16", + ) + self.in_channels = self.unet.in_channels + self.train(False) + + def forward(self, latent, timestep, text_embedding, guidance_scale): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latents = torch.cat([latent] * 2) + unet_out = self.unet.forward( + latents, timestep, text_embedding, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + unet = UnetModel() + unet = unet.half().cuda() + latent_model_input = torch.rand([BATCH_SIZE, 4, 64, 64]).half().cuda() + text_embeddings = ( + torch.rand([2 * BATCH_SIZE, args.max_length, 768]).half().cuda() + ) + guidance_scale = torch.tensor(1).to(torch.float32) + shark_unet = compile_through_fx( + unet, + ( + latent_model_input, + torch.tensor([1.0]).half().cuda(), + text_embeddings, + guidance_scale, + ), + model_name=model_name, + extra_args=extra_args, + ) + return shark_unet diff --git a/shark/examples/shark_inference/stable_diffusion/opt_params.py b/shark/examples/shark_inference/stable_diffusion/opt_params.py index 7b4a6dbe..d0146e0e 100644 --- a/shark/examples/shark_inference/stable_diffusion/opt_params.py +++ b/shark/examples/shark_inference/stable_diffusion/opt_params.py @@ -2,8 +2,10 @@ import sys from model_wrappers import ( get_vae32, get_vae16, - get_unet16_wrapped, - get_unet32_wrapped, + get_unet16_lms, + get_unet16, + get_unet32_lms, + get_unet32, get_clipped_text, ) from stable_args import args @@ -28,28 +30,38 @@ def get_unet(): return get_shark_model(bucket, model_name, iree_flags) else: bucket = "gs://shark_tank/prashant_nod" - model_name = "unet_22nov_fp16" + model_name = "unet_23nov_fp16" + model_name_lms = model_name + "_lms" iree_flags += [ "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=32", "--iree-flow-enable-conv-nchw-to-nhwc-transform", ] if args.import_mlir: - return get_unet16_wrapped(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) + return get_unet16_lms(model_name_lms, iree_flags), get_unet16( + model_name, iree_flags + ) + return get_shark_model( + bucket, model_name_lms, iree_flags + ), get_shark_model(bucket, model_name, iree_flags) # Tuned model is not present for `fp32` case. if args.precision == "fp32": bucket = "gs://shark_tank/prashant_nod" - model_name = "unet_22nov_fp32" + model_name = "unet_23nov_fp32" + model_name_lms = model_name + "_lms" iree_flags += [ "--iree-flow-enable-conv-nchw-to-nhwc-transform", "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=16", ] if args.import_mlir: - return get_unet32_wrapped(model_name, iree_flags) - return get_shark_model(bucket, model_name, iree_flags) + return get_unet32_lms(model_name + "_lms", iree_flags), get_unet32( + model_name, iree_flags + ) + return get_shark_model( + bucket, model_name_lms, iree_flags + ), get_shark_model(bucket, model_name, iree_flags) if args.precision == "int8": bucket = "gs://shark_tank/prashant_nod" @@ -58,12 +70,13 @@ def get_unet(): "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=32", ] - # TODO: Pass iree_flags to the exported model. - if args.import_mlir: - sys.exit( - "--import_mlir is not supported for the int8 model, try --no-import_mlir flag." - ) - return get_shark_model(bucket, model_name, iree_flags) + sys.exit("int8 model is currently in maintenance.") + # # TODO: Pass iree_flags to the exported model. + # if args.import_mlir: + # sys.exit( + # "--import_mlir is not supported for the int8 model, try --no-import_mlir flag." + # ) + # return get_shark_model(bucket, model_name, iree_flags) def get_vae():