diff --git a/web/index.py b/web/index.py
index ce4eee43..8e8f62dc 100644
--- a/web/index.py
+++ b/web/index.py
@@ -99,6 +99,11 @@ with gr.Blocks(css=demo_css) as shark_web:
                         step=0.1,
                         label="Guidance Scale",
                     )
+                    version = gr.Radio(
+                        label="Version",
+                        value="v1.4",
+                        choices=["v1.4", "v2.1base"],
+                    )
                 with gr.Row():
                     scheduler_key = gr.Dropdown(
                         label="Scheduler",
@@ -108,6 +113,7 @@ with gr.Blocks(css=demo_css) as shark_web:
                             "PNDM",
                             "LMSDiscrete",
                             "DPMSolverMultistep",
+                            "EulerDiscrete",
                         ],
                     )
             with gr.Group():
@@ -151,6 +157,7 @@ with gr.Blocks(css=demo_css) as shark_web:
                guidance,
                seed,
                scheduler_key,
+               version,
            ],
            outputs=[generated_img, std_output],
        )
@@ -162,6 +169,7 @@ with gr.Blocks(css=demo_css) as shark_web:
                guidance,
                seed,
                scheduler_key,
+               version,
            ],
            outputs=[generated_img, std_output],
        )
diff --git a/web/models/stable_diffusion/cache_objects.py b/web/models/stable_diffusion/cache_objects.py
index 79f11911..ed800e97 100644
--- a/web/models/stable_diffusion/cache_objects.py
+++ b/web/models/stable_diffusion/cache_objects.py
@@ -4,6 +4,7 @@ from diffusers import (
     PNDMScheduler,
     DDIMScheduler,
     DPMSolverMultistepScheduler,
+    EulerDiscreteScheduler,
 )
 from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
 from models.stable_diffusion.utils import set_iree_runtime_flags
@@ -27,9 +28,38 @@ schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
     "CompVis/stable-diffusion-v1-4",
     subfolder="scheduler",
 )
+schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    subfolder="scheduler",
+)
+
+schedulers2 = dict()
+schedulers2["PNDM"] = PNDMScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    subfolder="scheduler",
+)
+schedulers2["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    subfolder="scheduler",
+)
+schedulers2["DDIM"] = DDIMScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    subfolder="scheduler",
+)
+schedulers2[
+    "DPMSolverMultistep"
+] = DPMSolverMultistepScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    subfolder="scheduler",
+)
+schedulers2["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    subfolder="scheduler",
+)
 
 # set iree-runtime flags
 set_iree_runtime_flags(args)
+args.version = "v1.4"
 
 cache_obj = dict()
 
@@ -44,3 +74,16 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
     cache_obj["unet"],
     cache_obj["clip"],
 ) = (get_vae(args), get_unet(args), get_clip(args))
+
+args.version = "v2.1base"
+# cache tokenizer
+cache_obj["tokenizer2"] = CLIPTokenizer.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
+)
+
+# cache vae, unet and clip.
+(
+    cache_obj["vae2"],
+    cache_obj["unet2"],
+    cache_obj["clip2"],
+) = (get_vae(args), get_unet(args), get_clip(args))
diff --git a/web/models/stable_diffusion/main.py b/web/models/stable_diffusion/main.py
index 2d6b6d4f..6fe78c74 100644
--- a/web/models/stable_diffusion/main.py
+++ b/web/models/stable_diffusion/main.py
@@ -1,19 +1,24 @@
 import torch
 from PIL import Image
 from tqdm.auto import tqdm
-from models.stable_diffusion.cache_objects import cache_obj, schedulers
+from models.stable_diffusion.cache_objects import (
+    cache_obj,
+    schedulers,
+    schedulers2,
+)
 from models.stable_diffusion.stable_args import args
 from random import randint
 import numpy as np
 import time
 
 
-def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
+def set_ui_params(prompt, steps, guidance, seed, scheduler_key, version):
     args.prompt = [prompt]
     args.steps = steps
     args.guidance = guidance
     args.seed = seed
     args.scheduler = scheduler_key
+    args.version = version
 
 
 def stable_diff_inf(
@@ -22,6 +27,7 @@ def stable_diff_inf(
     guidance: float,
     seed: int,
     scheduler_key: str,
+    version: str,
 ):
 
     # Handle out of range seeds.
@@ -30,20 +36,29 @@ def stable_diff_inf(
     if seed < uint32_min or seed >= uint32_max:
         seed = randint(uint32_min, uint32_max)
 
-    set_ui_params(prompt, steps, guidance, seed, scheduler_key)
+    set_ui_params(prompt, steps, guidance, seed, scheduler_key, version)
     dtype = torch.float32 if args.precision == "fp32" else torch.half
     generator = torch.manual_seed(
         args.seed
     )  # Seed generator to create the inital latent noise
     guidance_scale = torch.tensor(args.guidance).to(torch.float32)
 
     # Initialize vae and unet models.
-    vae, unet, clip, tokenizer = (
-        cache_obj["vae"],
-        cache_obj["unet"],
-        cache_obj["clip"],
-        cache_obj["tokenizer"],
-    )
-    scheduler = schedulers[args.scheduler]
+    if args.version == "v2.1base":
+        vae, unet, clip, tokenizer = (
+            cache_obj["vae2"],
+            cache_obj["unet2"],
+            cache_obj["clip2"],
+            cache_obj["tokenizer2"],
+        )
+        scheduler = schedulers2[args.scheduler]
+    else:
+        vae, unet, clip, tokenizer = (
+            cache_obj["vae"],
+            cache_obj["unet"],
+            cache_obj["clip"],
+            cache_obj["tokenizer"],
+        )
+        scheduler = schedulers[args.scheduler]
 
     start = time.time()
     text_input = tokenizer(
@@ -125,7 +140,7 @@ def stable_diff_inf(
     total_time = time.time() - start
     text_output = f"prompt={args.prompt}"
-    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}"
+    text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
     text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
     print(f"\nAverage step time: {avg_ms}ms/it")
     text_output += "\nTotal image generation time: {0:.2f}sec".format(
diff --git a/web/models/stable_diffusion/model_wrappers.py b/web/models/stable_diffusion/model_wrappers.py
index 6b07f593..da454e87 100644
--- a/web/models/stable_diffusion/model_wrappers.py
+++ b/web/models/stable_diffusion/model_wrappers.py
@@ -3,7 +3,6 @@ from transformers import CLIPTextModel
 from models.stable_diffusion.utils import compile_through_fx
 import torch
 
-
 model_config = {
     "v2": "stabilityai/stable-diffusion-2",
     "v1.4": "CompVis/stable-diffusion-v1-4",
@@ -34,9 +33,14 @@ model_input = {
 
 
 def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
+
     text_encoder = CLIPTextModel.from_pretrained(
         "openai/clip-vit-large-patch14"
     )
+    if args.version == "v2":
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_config[args.version], subfolder="text_encoder"
+        )
 
     class CLIPText(torch.nn.Module):
         def __init__(self):
@@ -58,13 +62,16 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
 
 
 def get_vae_mlir(args, model_name="vae", extra_args=[]):
+    # revision param for from_pretrained defaults to "main" => fp32
+    model_revision = "fp16" if args.precision == "fp16" else "main"
+
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.vae = AutoencoderKL.from_pretrained(
                 model_config[args.version],
                 subfolder="vae",
-                revision="fp16",
+                revision=model_revision,
             )
 
         def forward(self, input):
@@ -72,10 +79,17 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
             return (x / 2 + 0.5).clamp(0, 1)
 
     vae = VaeModel()
-    vae = vae.half().cuda()
-    inputs = tuple(
-        [inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
-    )
+    if args.precision == "fp16":
+        vae = vae.half().cuda()
+        inputs = tuple(
+            [
+                inputs.half().cuda()
+                for inputs in model_input[args.version]["vae"]
+            ]
+        )
+    else:
+        inputs = model_input[args.version]["vae"]
+
     shark_vae = compile_through_fx(
         args,
         vae,
@@ -116,13 +130,15 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
 
 
 def get_unet_mlir(args, model_name="unet", extra_args=[]):
+    model_revision = "fp16" if args.precision == "fp16" else "main"
+
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.unet = UNet2DConditionModel.from_pretrained(
                 model_config[args.version],
                 subfolder="unet",
-                revision="fp16",
+                revision=model_revision,
             )
             self.in_channels = self.unet.in_channels
             self.train(False)
@@ -140,13 +156,16 @@ def get_unet_mlir(args, model_name="unet", extra_args=[]):
         return noise_pred
 
     unet = UnetModel()
-    unet = unet.half().cuda()
-    inputs = tuple(
-        [
-            inputs.half().cuda() if len(inputs.shape) != 0 else inputs
-            for inputs in model_input[args.version]["unet"]
-        ]
-    )
+    if args.precision == "fp16":
+        unet = unet.half().cuda()
+        inputs = tuple(
+            [
+                inputs.half().cuda() if len(inputs.shape) != 0 else inputs
+                for inputs in model_input[args.version]["unet"]
+            ]
+        )
+    else:
+        inputs = model_input[args.version]["unet"]
     shark_unet = compile_through_fx(
         args,
         unet,
diff --git a/web/models/stable_diffusion/opt_params.py b/web/models/stable_diffusion/opt_params.py
index b7c351b2..80431393 100644
--- a/web/models/stable_diffusion/opt_params.py
+++ b/web/models/stable_diffusion/opt_params.py
@@ -23,6 +23,8 @@ def get_unet(args):
     else:
         bucket = "gs://shark_tank/stable_diffusion"
         model_name = "unet_1dec_fp16"
+        if args.version == "v2.1base":
+            model_name = "unet2base_8dec_fp16"
         iree_flags += [
             "--iree-flow-enable-padding-linalg-ops",
             "--iree-flow-linalg-ops-padding-size=32",
@@ -55,6 +57,8 @@ def get_vae(args):
     if args.precision == "fp16":
         bucket = "gs://shark_tank/stable_diffusion"
         model_name = "vae_1dec_fp16"
+        if args.version == "v2.1base":
+            model_name = "vae2base_8dec_fp16"
         iree_flags += [
             "--iree-flow-enable-padding-linalg-ops",
             "--iree-flow-linalg-ops-padding-size=32",
@@ -116,6 +120,8 @@ def get_clip(args):
     )
     bucket = "gs://shark_tank/stable_diffusion"
     model_name = "clip_1dec_fp32"
+    if args.version == "v2.1base":
+        model_name = "clip2base_8dec_fp32"
     iree_flags += [
         "--iree-flow-linalg-ops-padding-size=16",
         "--iree-flow-enable-padding-linalg-ops",
diff --git a/web/models/stable_diffusion/utils.py b/web/models/stable_diffusion/utils.py
index d8ba8dc8..f9647646 100644
--- a/web/models/stable_diffusion/utils.py
+++ b/web/models/stable_diffusion/utils.py
@@ -12,6 +12,7 @@ def set_iree_runtime_flags(args):
     ]
     if "vulkan" in args.device:
         set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
+    return