diff --git a/apps/__init__.py b/apps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/stable_diffusion/__init__.py b/apps/stable_diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/stable_diffusion/resources/base_model.json b/apps/stable_diffusion/resources/base_model.json new file mode 100644 index 00000000..fc70e9c9 --- /dev/null +++ b/apps/stable_diffusion/resources/base_model.json @@ -0,0 +1,98 @@ +{ + "stabilityai/stable-diffusion-2-1": { + "unet": { + "latents": { + "shape": [ + "1*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "embedding": { + "shape": [ + "2*batch_size", + "max_len", + 1024 + ], + "dtype": "f32" + }, + "guidance_scale": { + "shape": 2, + "dtype": "f32" + } + }, + "vae": { + "latents" : { + "shape" : [ + "1*batch_size",4,"height","width" + ], + "dtype":"f32" + } + }, + "clip": { + "token" : { + "shape" : [ + "2*batch_size", + "max_len" + ], + "dtype":"i64" + } + } + }, + "CompVis/stable-diffusion-v1-4": { + "unet": { + "latents": { + "shape": [ + "1*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "embedding": { + "shape": [ + "2*batch_size", + "max_len", + 768 + ], + "dtype": "f32" + }, + "guidance_scale": { + "shape": 2, + "dtype": "f32" + } + }, + "vae": { + "latents" : { + "shape" : [ + "1*batch_size",4,"height","width" + ], + "dtype":"f32" + } + }, + "clip": { + "token" : { + "shape" : [ + "2*batch_size", + "max_len" + ], + "dtype":"i64" + } + } + } +} diff --git a/apps/stable_diffusion/resources/model_db.json b/apps/stable_diffusion/resources/model_db.json new file mode 100644 index 00000000..efb88e16 --- /dev/null +++ b/apps/stable_diffusion/resources/model_db.json @@ -0,0 +1,177 @@ +[ + { + "stablediffusion/untuned":"gs://shark_tank/stable_diffusion", + "stablediffusion/tuned":"gs://shark_tank/sd_tuned", + "stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda", + "anythingv3/untuned":"gs://shark_tank/sd_anythingv3", + "anythingv3/tuned":"gs://shark_tank/sd_tuned", + "anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda", + "analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion", + "analogdiffusion/tuned":"gs://shark_tank/sd_tuned", + "analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda", + "openjourney/untuned":"gs://shark_tank/sd_openjourney", + "openjourney/tuned":"gs://shark_tank/sd_tuned", + "dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion" + }, + { + "stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16", + "stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned", + "stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned", + "stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32", + "stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16", + "stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned", + "stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned", + "stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16", + "stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32", + "stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32", + "stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16", + "stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2", + 
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned", + "stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64", + "stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned", + "stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned", + "stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned", + "stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32", + "stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64", + "stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16", + "stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16", + "stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16", + "stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32", + "anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16", + "anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned", + "anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned", + "anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32", + "anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16", + "anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned", + "anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned", + "anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16", + "anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32", + "anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32", + "anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32", + "analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16", + "analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned", + "analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned", + "analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32", + "analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16", + "analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned", + "analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned", + "analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16", + "analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32", + "analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32", + "analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32", + "openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64", + "openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64", + "openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16", + "openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16", + "openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32", + 
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32", + "openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64", + "dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77", + "dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77", + "dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16", + "dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16", + "dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32", + "dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32", + "dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77" + }, + { + "unet": { + "tuned": { + "fp16": { + "default_compilation_flags": [] + }, + "fp32": { + "default_compilation_flags": [] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32" + ], + "specified_compilation_flags": { + "cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"], + "default_device": ["--iree-flow-enable-conv-img2col-transform"] + } + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-enable-conv-nchw-to-nhwc-transform", + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=16" + ] + } + } + }, + "vae": { + "tuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-enable-conv-nchw-to-nhwc-transform", + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=16" + ] + } + } + }, + "clip": { + "tuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + } + } + } + } +] diff --git a/apps/stable_diffusion/resources/opt_flags.json b/apps/stable_diffusion/resources/opt_flags.json new file mode 100644 index 00000000..ae30855b --- /dev/null +++ b/apps/stable_diffusion/resources/opt_flags.json @@ -0,0 +1,95 @@ + { + "unet": { + "tuned": { + "fp16": { + "default_compilation_flags": [] + }, + "fp32": { + "default_compilation_flags": [] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32" + ], + "specified_compilation_flags": { + "cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"], + "default_device": ["--iree-flow-enable-conv-img2col-transform"] + } + }, + "fp32": { + "default_compilation_flags": [ + 
"--iree-flow-enable-conv-nchw-to-nhwc-transform", + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=16" + ] + } + } + }, + "vae": { + "tuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=32", + "--iree-flow-enable-conv-img2col-transform" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-enable-conv-nchw-to-nhwc-transform", + "--iree-flow-enable-padding-linalg-ops", + "--iree-flow-linalg-ops-padding-size=16" + ] + } + } + }, + "clip": { + "tuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + } + }, + "untuned": { + "fp16": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + }, + "fp32": { + "default_compilation_flags": [ + "--iree-flow-linalg-ops-padding-size=16", + "--iree-flow-enable-padding-linalg-ops" + ] + } + } + } + } diff --git a/apps/stable_diffusion/resources/prompts.json b/apps/stable_diffusion/resources/prompts.json new file mode 100644 index 00000000..4c8370db --- /dev/null +++ b/apps/stable_diffusion/resources/prompts.json @@ -0,0 +1,8 @@ +[["A high tech solarpunk utopia in the Amazon rainforest"], +["A pikachu fine dining with a view to the Eiffel Tower"], +["A mecha robot in a favela in expressionist style"], +["an insect robot preparing a delicious meal"], +["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"], +["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"], +["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"], +["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]] diff --git a/apps/stable_diffusion/scripts/__init__.py b/apps/stable_diffusion/scripts/__init__.py new file mode 100644 index 00000000..b4fcdd64 --- /dev/null +++ b/apps/stable_diffusion/scripts/__init__.py @@ -0,0 +1 @@ +from .txt2img import txt2img_inf diff --git a/apps/stable_diffusion/scripts/img2img.py b/apps/stable_diffusion/scripts/img2img.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/stable_diffusion/scripts/txt2img.py b/apps/stable_diffusion/scripts/txt2img.py new file mode 100644 index 00000000..cb36e72c --- /dev/null +++ b/apps/stable_diffusion/scripts/txt2img.py @@ -0,0 +1,240 @@ +import os + +os.environ["AMD_ENABLE_LLPC"] = "1" + +import torch +import re +import time +from pathlib import Path +from datetime import datetime as dt +from dataclasses import dataclass +from csv import DictWriter +from apps.stable_diffusion.src import ( + args, + Text2ImagePipeline, + get_schedulers, + set_init_device_flags, +) + + +@dataclass +class Config: + model_id: str + ckpt_loc: str + precision: str + batch_size: int + max_length: 
int + height: int + width: int + device: str + + +# This has to come before importing cache objects +if args.clear_all: + print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE") + from glob import glob + import shutil + + vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb")) + for vmfb in vmfbs: + if os.path.exists(vmfb): + os.remove(vmfb) + home = os.path.expanduser("~") + if os.name == "nt": # Windows + appdata = os.getenv("LOCALAPPDATA") + shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True) + shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True) + elif os.name == "unix": + shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache")) + shutil.rmtree(os.path.join(home, ".local/shark_tank")) + + +# save output images and the inputs correspoding to it. +def save_output_img(output_img): + output_path = args.output_dir if args.output_dir else Path.cwd() + generated_imgs_path = Path(output_path, "generated_imgs") + generated_imgs_path.mkdir(parents=True, exist_ok=True) + csv_path = Path(generated_imgs_path, "imgs_details.csv") + + prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15]) + out_img_name = ( + f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}" + ) + out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") + output_img.save(out_img_path, quality=95, subsampling=0) + + new_entry = { + "VARIANT": args.hf_model_id, + "SCHEDULER": args.scheduler, + "PROMPT": args.prompts[0], + "NEG_PROMPT": args.negative_prompts[0], + "SEED": args.seed, + "CFG_SCALE": args.guidance_scale, + "PRECISION": args.precision, + "STEPS": args.steps, + "HEIGHT": args.height, + "WIDTH": args.width, + "MAX_LENGTH": args.max_length, + "OUTPUT": out_img_path, + } + + with open(csv_path, "a") as csv_obj: + dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys())) + dictwriter_obj.writerow(new_entry) + csv_obj.close() + + +txt2img_obj = None +config_obj = None +schedulers = None + +# Exposed to UI. 
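+# The compiled pipeline is cached in the module-level globals above; a new
+# Text2ImagePipeline is built only when the incoming request's Config differs
+# from the previous one, so repeated generations with identical settings reuse
+# the already-compiled modules.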
+def txt2img_inf( + prompt: str, + negative_prompt: str, + height: int, + width: int, + steps: int, + guidance_scale: float, + seed: int, + batch_size: int, + scheduler: str, + model_id: str, + custom_model_id: str, + ckpt_file_obj, + precision: str, + device: str, + max_length: int, +): + global txt2img_obj + global config_obj + global schedulers + + args.prompts = [prompt] + args.negative_prompts = [negative_prompt] + args.guidance_scale = guidance_scale + args.seed = seed + args.steps = steps + args.scheduler = scheduler + args.hf_model_id = custom_model_id if custom_model_id else model_id + args.ckpt_loc = ckpt_file_obj.name if ckpt_file_obj else "" + dtype = torch.float32 if precision == "fp32" else torch.half + cpu_scheduling = not scheduler.startswith("Shark") + new_config_obj = Config( + args.hf_model_id, + args.ckpt_loc, + precision, + batch_size, + max_length, + height, + width, + device, + ) + if config_obj != new_config_obj: + config_obj = new_config_obj + args.precision = precision + args.batch_size = batch_size + args.max_length = max_length + args.height = height + args.width = width + args.device = device.split("=>", 1)[1].strip() + args.use_tuned = True + args.import_mlir = False + set_init_device_flags() + schedulers = get_schedulers(model_id) + scheduler_obj = schedulers[scheduler] + txt2img_obj = Text2ImagePipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + ) + txt2img_obj.scheduler = schedulers[scheduler] + + start_time = time.time() + txt2img_obj.log = "" + generated_imgs = txt2img_obj.generate_images( + prompt, + negative_prompt, + batch_size, + height, + width, + steps, + guidance_scale, + seed, + args.max_length, + dtype, + args.use_base_vae, + cpu_scheduling, + ) + total_time = time.time() - start_time + save_output_img(generated_imgs[0]) + text_output = f"prompt={args.prompts}" + text_output += f"\nnegative prompt={args.negative_prompts}" + text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" + text_output += f"\nscheduler={args.scheduler}, device={device}" + text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}" + text_output += ( + f", batch size={args.batch_size}, max_length={args.max_length}" + ) + text_output += txt2img_obj.log + text_output += f"\nTotal image generation time: {total_time:.4f}sec" + + return generated_imgs, text_output + + +if __name__ == "__main__": + dtype = torch.float32 if args.precision == "fp32" else torch.half + cpu_scheduling = not args.scheduler.startswith("Shark") + set_init_device_flags() + schedulers = get_schedulers(args.hf_model_id) + scheduler_obj = schedulers[args.scheduler] + + txt2img_obj = Text2ImagePipeline.from_pretrained( + scheduler_obj, + args.import_mlir, + args.hf_model_id, + args.ckpt_loc, + args.precision, + args.max_length, + args.batch_size, + args.height, + args.width, + args.use_base_vae, + ) + + start_time = time.time() + generated_imgs = txt2img_obj.generate_images( + args.prompts, + args.negative_prompts, + args.batch_size, + args.height, + args.width, + args.steps, + args.guidance_scale, + args.seed, + args.max_length, + dtype, + args.use_base_vae, + cpu_scheduling, + ) + total_time = time.time() - start_time + text_output = f"prompt={args.prompts}" + text_output += f"\nnegative prompt={args.negative_prompts}" + text_output += f"\nmodel_id={args.hf_model_id}, 
ckpt_loc={args.ckpt_loc}" + text_output += f"\nscheduler={args.scheduler}, device={args.device}" + text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}" + text_output += ( + f", batch size={args.batch_size}, max_length={args.max_length}" + ) + text_output += txt2img_obj.log + text_output += f"\nTotal image generation time: {total_time:.4f}sec" + + save_output_img(generated_imgs[0]) + print(text_output) diff --git a/apps/stable_diffusion/src/__init__.py b/apps/stable_diffusion/src/__init__.py new file mode 100644 index 00000000..5184720b --- /dev/null +++ b/apps/stable_diffusion/src/__init__.py @@ -0,0 +1,8 @@ +from .utils import ( + args, + set_init_device_flags, + prompt_examples, + get_available_devices, +) +from .pipelines import Text2ImagePipeline +from .schedulers import get_schedulers diff --git a/apps/stable_diffusion/src/models/__init__.py b/apps/stable_diffusion/src/models/__init__.py new file mode 100644 index 00000000..2ddbbe97 --- /dev/null +++ b/apps/stable_diffusion/src/models/__init__.py @@ -0,0 +1,2 @@ +from .model_wrappers import SharkifyStableDiffusionModel +from .opt_params import get_vae, get_unet, get_clip, get_tokenizer diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py new file mode 100644 index 00000000..f9d33fea --- /dev/null +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -0,0 +1,229 @@ +from diffusers import AutoencoderKL, UNet2DConditionModel +from transformers import CLIPTextModel +from collections import defaultdict +import torch +import sys +import traceback +import re +from ..utils import compile_through_fx, get_opt_flags, base_models, args + + +# These shapes are parameter dependent. +def replace_shape_str(shape, max_len, width, height, batch_size): + new_shape = [] + for i in range(len(shape)): + if shape[i] == "max_len": + new_shape.append(max_len) + elif shape[i] == "height": + new_shape.append(height) + elif shape[i] == "width": + new_shape.append(width) + elif isinstance(shape[i], str): + if "batch_size" in shape[i]: + mul_val = int(shape[i].split("*")[0]) + new_shape.append(batch_size * mul_val) + else: + new_shape.append(shape[i]) + return new_shape + + +# Get the input info for various models i.e. "unet", "clip", "vae". 
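+# For example, the stable-diffusion-2-1 "unet" entry in base_model.json lists
+# "embedding" with shape ["2*batch_size", "max_len", 1024]; with batch_size=1
+# and max_len=64 this resolves to a random f32 tensor of shape (2, 64, 1024).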
+def get_input_info(model_info, max_len, width, height, batch_size): + dtype_config = {"f32": torch.float32, "i64": torch.int64} + input_map = defaultdict(list) + for k in model_info: + for inp in model_info[k]: + shape = model_info[k][inp]["shape"] + dtype = dtype_config[model_info[k][inp]["dtype"]] + tensor = None + if isinstance(shape, list): + clean_shape = replace_shape_str( + shape, max_len, width, height, batch_size + ) + if dtype == torch.int64: + tensor = torch.randint(1, 3, tuple(clean_shape)) + else: + tensor = torch.randn(*clean_shape).to(dtype) + elif isinstance(shape, int): + tensor = torch.tensor(shape).to(dtype) + else: + sys.exit("shape isn't specified correctly.") + input_map[k].append(tensor) + return input_map + + +class SharkifyStableDiffusionModel: + def __init__( + self, + model_id: str, + custom_weights: str, + precision: str, + max_len: int = 64, + width: int = 512, + height: int = 512, + batch_size: int = 1, + use_base_vae: bool = False, + ): + self.check_params(max_len, width, height) + self.max_len = max_len + self.height = height // 8 + self.width = width // 8 + self.batch_size = batch_size + self.model_id = model_id if custom_weights == "" else custom_weights + self.precision = precision + self.base_vae = use_base_vae + self.model_name = ( + str(batch_size) + + "_" + + str(max_len) + + "_" + + str(height) + + "_" + + str(width) + + "_" + + precision + ) + # We need a better naming convention for the .vmfbs because despite + # using the custom model variant the .vmfb names remain the same and + # it'll always pick up the compiled .vmfb instead of compiling the + # custom model. + # So, currently, we add `self.model_id` in the `self.model_name` of + # .vmfb file. + # TODO: Have a better way of naming the vmfbs using self.model_name. 
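+ # Until then, append a sanitized model_id (non-word characters replaced
+ # with "_") so vmfbs of different variants/custom weights don't collide.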
+ + model_name = re.sub(r"\W+", "_", self.model_id) + if model_name[0] == "_": + model_name = model_name[1:] + self.model_name = self.model_name + "_" + model_name + + def check_params(self, max_len, width, height): + if not (max_len >= 32 and max_len <= 77): + sys.exit("please specify max_len in the range [32, 77].") + if not (width % 8 == 0 and width >= 384): + sys.exit("width should be greater than 384 and multiple of 8") + if not (height % 8 == 0 and height >= 384): + sys.exit("height should be greater than 384 and multiple of 8") + + def get_vae(self): + class VaeModel(torch.nn.Module): + def __init__(self, model_id=self.model_id, base_vae=self.base_vae): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + model_id, + subfolder="vae", + ) + self.base_vae = base_vae + + def forward(self, input): + if not self.base_vae: + input = 1 / 0.18215 * input + x = self.vae.decode(input, return_dict=False)[0] + x = (x / 2 + 0.5).clamp(0, 1) + if self.base_vae: + return x + x = x * 255.0 + return x.round() + + vae = VaeModel() + inputs = tuple(self.inputs["vae"]) + is_f16 = True if self.precision == "fp16" else False + vae_name = "base_vae" if self.base_vae else "vae" + shark_vae = compile_through_fx( + vae, + inputs, + is_f16=is_f16, + model_name=vae_name + self.model_name, + extra_args=get_opt_flags("vae", precision=self.precision), + ) + return shark_vae + + def get_unet(self): + class UnetModel(torch.nn.Module): + def __init__(self, model_id=self.model_id): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + ) + self.in_channels = self.unet.in_channels + self.train(False) + + def forward( + self, latent, timestep, text_embedding, guidance_scale + ): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
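+ # The doubled batch is split back into unconditional and text-conditioned
+ # halves below and combined as:
+ #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)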
+ latents = torch.cat([latent] * 2) + unet_out = self.unet.forward( + latents, timestep, text_embedding, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + unet = UnetModel() + is_f16 = True if self.precision == "fp16" else False + inputs = tuple(self.inputs["unet"]) + input_mask = [True, True, True, False] + shark_unet = compile_through_fx( + unet, + inputs, + model_name="unet" + self.model_name, + is_f16=is_f16, + f16_input_mask=input_mask, + extra_args=get_opt_flags("unet", precision=self.precision), + ) + return shark_unet + + def get_clip(self): + class CLIPText(torch.nn.Module): + def __init__(self, model_id=self.model_id): + super().__init__() + self.text_encoder = CLIPTextModel.from_pretrained( + model_id, + subfolder="text_encoder", + ) + + def forward(self, input): + return self.text_encoder(input)[0] + + clip_model = CLIPText() + + shark_clip = compile_through_fx( + clip_model, + tuple(self.inputs["clip"]), + model_name="clip" + self.model_name, + extra_args=get_opt_flags("clip", precision="fp32"), + ) + return shark_clip + + def __call__(self): + + for model_id in base_models: + self.inputs = get_input_info( + base_models[model_id], + self.max_len, + self.width, + self.height, + self.batch_size, + ) + try: + compiled_clip = self.get_clip() + compiled_unet = self.get_unet() + compiled_vae = self.get_vae() + except Exception as e: + if args.enable_stack_trace: + traceback.print_exc() + print("Retrying with a different base model configuration") + continue + # This is done just because in main.py we are basing the choice of tokenizer and scheduler + # on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base + # model and rely on retrying method to find the input configuration, we should also update + # the knowledge of base model id accordingly into `args.hf_model_id`. + if args.ckpt_loc != "": + args.hf_model_id = model_id + return compiled_clip, compiled_unet, compiled_vae + sys.exit( + "Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues" + ) diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py new file mode 100644 index 00000000..caf9a768 --- /dev/null +++ b/apps/stable_diffusion/src/models/opt_params.py @@ -0,0 +1,113 @@ +import sys +from transformers import CLIPTokenizer +from ..utils import models_db, args, get_shark_model + + +hf_model_variant_map = { + "Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"], + "dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"], + "prompthero/openjourney": ["openjourney", "v2_1base"], + "wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"], + "stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"], + "stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"], + "CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"], +} + + +def get_params(bucket_key, model_key, model, is_tuned, precision): + iree_flags = [] + if len(args.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" + ) + + # Disable bindings fusion to work with moltenVK. 
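+ # (macOS routes Vulkan through MoltenVK, hence the platform check.)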
+ if sys.platform == "darwin": + iree_flags.append("-iree-stream-fuse-binding=false") + + try: + bucket = models_db[0][bucket_key] + model_name = models_db[1][model_key] + iree_flags += models_db[2][model][is_tuned][precision][ + "default_compilation_flags" + ] + except KeyError: + raise Exception( + f"{bucket_key}/{model_key} is not present in the models database" + ) + + if ( + "specified_compilation_flags" + in models_db[2][model][is_tuned][precision] + ): + device = ( + args.device + if "://" not in args.device + else args.device.split("://")[0] + ) + if ( + device + not in models_db[2][model][is_tuned][precision][ + "specified_compilation_flags" + ] + ): + device = "default_device" + iree_flags += models_db[2][model][is_tuned][precision][ + "specified_compilation_flags" + ][device] + + return bucket, model_name, iree_flags + + +def get_unet(): + variant, version = hf_model_variant_map[args.hf_model_id] + # Tuned model is present only for `fp16` precision. + is_tuned = "tuned" if args.use_tuned else "untuned" + if "vulkan" not in args.device and args.use_tuned: + bucket_key = f"{variant}/{is_tuned}/{args.device}" + model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}" + else: + bucket_key = f"{variant}/{is_tuned}" + model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}" + + bucket, model_name, iree_flags = get_params( + bucket_key, model_key, "unet", is_tuned, args.precision + ) + return get_shark_model(bucket, model_name, iree_flags) + + +def get_vae(): + variant, version = hf_model_variant_map[args.hf_model_id] + # Tuned model is present only for `fp16` precision. + is_tuned = "tuned" if args.use_tuned else "untuned" + is_base = "/base" if args.use_base_vae else "" + if "vulkan" not in args.device and args.use_tuned: + bucket_key = f"{variant}/{is_tuned}/{args.device}" + model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}" + else: + bucket_key = f"{variant}/{is_tuned}" + model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}" + + bucket, model_name, iree_flags = get_params( + bucket_key, model_key, "vae", is_tuned, args.precision + ) + return get_shark_model(bucket, model_name, iree_flags) + + +def get_clip(): + variant, version = hf_model_variant_map[args.hf_model_id] + bucket_key = f"{variant}/untuned" + model_key = ( + f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned" + ) + bucket, model_name, iree_flags = get_params( + bucket_key, model_key, "clip", "untuned", "fp32" + ) + return get_shark_model(bucket, model_name, iree_flags) + + +def get_tokenizer(): + tokenizer = CLIPTokenizer.from_pretrained( + args.hf_model_id, subfolder="tokenizer" + ) + return tokenizer diff --git a/apps/stable_diffusion/src/pipelines/__init__.py b/apps/stable_diffusion/src/pipelines/__init__.py new file mode 100644 index 00000000..9ed9f966 --- /dev/null +++ b/apps/stable_diffusion/src/pipelines/__init__.py @@ -0,0 +1 @@ +from .pipeline_shark_stable_diffusion_txt2img import Text2ImagePipeline diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py new file mode 100644 index 00000000..8fc96359 --- 
/dev/null +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py @@ -0,0 +1,132 @@ +import torch +from tqdm.auto import tqdm +import numpy as np +from random import randint +from transformers import CLIPTokenizer +from typing import Union +from shark.shark_inference import SharkInference +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) +from ..schedulers import SharkEulerDiscreteScheduler +from .pipeline_shark_stable_diffusion_utils import StableDiffusionPipeline + + +class Text2ImagePipeline(StableDiffusionPipeline): + def __init__( + self, + vae: SharkInference, + text_encoder: SharkInference, + tokenizer: CLIPTokenizer, + unet: SharkInference, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + SharkEulerDiscreteScheduler, + ], + ): + super().__init__(vae, text_encoder, tokenizer, unet, scheduler) + + def prepare_latents( + self, + batch_size, + height, + width, + generator, + num_inference_steps, + dtype, + ): + latents = torch.randn( + ( + batch_size, + 4, + height // 8, + width // 8, + ), + generator=generator, + dtype=torch.float32, + ).to(dtype) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + latents = latents * self.scheduler.init_noise_sigma + return latents + + def generate_images( + self, + prompts, + neg_prompts, + batch_size, + height, + width, + num_inference_steps, + guidance_scale, + seed, + max_length, + dtype, + use_base_vae, + cpu_scheduling, + ): + # prompts and negative prompts must be a list. + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(neg_prompts, str): + neg_prompts = [neg_prompts] + + prompts = prompts * batch_size + neg_prompts = neg_prompts * batch_size + + # seed generator to create the inital latent noise. Also handle out of range seeds. + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + generator = torch.manual_seed(seed) + + # Get initial latents + init_latents = self.prepare_latents( + batch_size=batch_size, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + dtype=dtype, + ) + + # Get text embeddings from prompts + text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length) + + # guidance scale as a float32 tensor. 
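+ # The compiled UNet takes guidance_scale as a tensor input (see
+ # base_model.json), so it is wrapped here rather than passed as a Python float.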
+ guidance_scale = torch.tensor(guidance_scale).to(torch.float32) + + # Get Image latents + latents = self.produce_img_latents( + latents=init_latents, + text_embeddings=text_embeddings, + guidance_scale=guidance_scale, + total_timesteps=self.scheduler.timesteps, + dtype=dtype, + cpu_scheduling=cpu_scheduling, + ) + + # Img latents -> PIL images + all_imgs = [] + for i in tqdm(range(0, latents.shape[0], batch_size)): + imgs = self.decode_latents( + latents=latents[i : i + batch_size], + use_base_vae=use_base_vae, + cpu_scheduling=cpu_scheduling, + ) + all_imgs.extend(imgs) + + return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py new file mode 100644 index 00000000..a3408ec2 --- /dev/null +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -0,0 +1,206 @@ +import torch +from transformers import CLIPTokenizer +import torchvision.transforms as T +from tqdm.auto import tqdm +import time +from typing import Union +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) +from shark.shark_inference import SharkInference +from ..schedulers import SharkEulerDiscreteScheduler +from ..models import ( + SharkifyStableDiffusionModel, + get_vae, + get_clip, + get_unet, + get_tokenizer, +) +from ..utils import start_profiling, end_profiling, preprocessCKPT + + +class StableDiffusionPipeline: + def __init__( + self, + vae: SharkInference, + text_encoder: SharkInference, + tokenizer: CLIPTokenizer, + unet: SharkInference, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + SharkEulerDiscreteScheduler, + ], + ): + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.unet = unet + self.scheduler = scheduler + # TODO: Implement using logging python utility. 
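+ # self.log accumulates human-readable timing info (CLIP/VAE inference and
+ # average step times) that the txt2img script appends to its text output.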
+ self.log = "" + + def encode_prompts(self, prompts, neg_prompts, max_length): + # Tokenize text and get embeddings + text_input = self.tokenizer( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + # Get unconditional embeddings as well + uncond_input = self.tokenizer( + neg_prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + text_input = torch.cat([uncond_input.input_ids, text_input.input_ids]) + + clip_inf_start = time.time() + text_embeddings = self.text_encoder("forward", (text_input,)) + clip_inf_time = (time.time() - clip_inf_start) * 1000 + self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}" + + return text_embeddings + + def decode_latents(self, latents, use_base_vae, cpu_scheduling): + if use_base_vae: + latents = 1 / 0.18215 * latents + + latents_numpy = latents + if cpu_scheduling: + latents_numpy = latents.detach().numpy() + + profile_device = start_profiling(file_path="vae.rdc") + vae_start = time.time() + images = self.vae("forward", (latents_numpy,)) + vae_inf_time = (time.time() - vae_start) * 1000 + end_profiling(profile_device) + self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}" + + if use_base_vae: + images = torch.from_numpy(images) + images = (images.detach().cpu() * 255.0).numpy() + images = images.round() + + transform = T.ToPILImage() + pil_images = [ + transform(image) + for image in torch.from_numpy(images).to(torch.uint8) + ] + return pil_images + + def produce_img_latents( + self, + latents, + text_embeddings, + guidance_scale, + total_timesteps, + dtype, + cpu_scheduling, + return_all_latents=False, + ): + + step_time_sum = 0 + latent_history = [latents] + text_embeddings = torch.from_numpy(text_embeddings).to(dtype) + text_embeddings_numpy = text_embeddings.detach().numpy() + for i, t in tqdm(enumerate(total_timesteps)): + step_start_time = time.time() + timestep = torch.tensor([t]).to(dtype).detach().numpy() + latent_model_input = self.scheduler.scale_model_input(latents, t) + if cpu_scheduling: + latent_model_input = latent_model_input.detach().numpy() + + # Profiling Unet. 
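+ # start_profiling is a no-op (returns None) unless --vulkan_debug_utils is
+ # set and the device is vulkan; see src/utils/profiler.py.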
+ profile_device = start_profiling(file_path="unet.rdc") + noise_pred = self.unet( + "forward", + ( + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + ), + send_to_host=False, + ) + end_profiling(profile_device) + + if cpu_scheduling: + noise_pred = torch.from_numpy(noise_pred.to_host()) + latents = self.scheduler.step( + noise_pred, t, latents + ).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents) + + latent_history.append(latents) + step_time = (time.time() - step_start_time) * 1000 + # self.log += ( + # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" + # ) + step_time_sum += step_time + + avg_step_time = step_time_sum / len(total_timesteps) + self.log += f"\nAverage step time: {avg_step_time}ms/it" + + if not return_all_latents: + return latents + all_latents = torch.cat(latent_history, dim=0) + return all_latents + + @classmethod + def from_pretrained( + cls, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + SharkEulerDiscreteScheduler, + ], + import_mlir: bool, + model_id: str, + ckpt_loc: str, + precision: str, + max_length: int, + batch_size: int, + height: int, + width: int, + use_base_vae: bool, + ): + init_kwargs = None + if import_mlir: + if ckpt_loc: + preprocessCKPT() + mlir_import = SharkifyStableDiffusionModel( + model_id, + ckpt_loc, + precision, + max_len=max_length, + batch_size=batch_size, + height=height, + width=width, + use_base_vae=use_base_vae, + ) + clip, unet, vae = mlir_import() + return cls(vae, clip, get_tokenizer(), unet, scheduler) + return cls( + get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler + ) diff --git a/apps/stable_diffusion/src/schedulers/__init__.py b/apps/stable_diffusion/src/schedulers/__init__.py new file mode 100644 index 00000000..d143edc6 --- /dev/null +++ b/apps/stable_diffusion/src/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .sd_schedulers import get_schedulers +from .shark_eulerdiscrete import SharkEulerDiscreteScheduler diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py new file mode 100644 index 00000000..514fce4c --- /dev/null +++ b/apps/stable_diffusion/src/schedulers/sd_schedulers.py @@ -0,0 +1,51 @@ +from diffusers import ( + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, +) +from .shark_eulerdiscrete import ( + SharkEulerDiscreteScheduler, +) + + +def get_schedulers(model_id): + schedulers = dict() + schedulers["PNDM"] = PNDMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDIM"] = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "DPMSolverMultistep" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "SharkEulerDiscrete" + ] = SharkEulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["SharkEulerDiscrete"].compile() + return schedulers diff --git 
a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py new file mode 100644 index 00000000..29494d4c --- /dev/null +++ b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py @@ -0,0 +1,139 @@ +import sys +import numpy as np +from typing import List, Optional, Tuple, Union +from diffusers import ( + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, +) +from diffusers.configuration_utils import register_to_config +from ..utils import compile_through_fx, get_shark_model, args +import torch + + +class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + prediction_type, + ) + + def compile(self): + SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" + BATCH_SIZE = args.batch_size + + model_input = { + "euler": { + "latent": torch.randn( + BATCH_SIZE, 4, args.height // 8, args.width // 8 + ), + "output": torch.randn( + BATCH_SIZE, 4, args.height // 8, args.width // 8 + ), + "sigma": torch.tensor(1).to(torch.float32), + "dt": torch.tensor(1).to(torch.float32), + }, + } + + example_latent = model_input["euler"]["latent"] + example_output = model_input["euler"]["output"] + if args.precision == "fp16": + example_latent = example_latent.half() + example_output = example_output.half() + example_sigma = model_input["euler"]["sigma"] + example_dt = model_input["euler"]["dt"] + + class ScalingModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, latent, sigma): + return latent / ((sigma**2 + 1) ** 0.5) + + class SchedulerStepModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, noise_pred, sigma, latent, dt): + pred_original_sample = latent - sigma * noise_pred + derivative = (latent - pred_original_sample) / sigma + return latent + derivative * dt + + iree_flags = [] + if len(args.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" + ) + # Disable bindings fusion to work with moltenVK. 
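+ # Same MoltenVK workaround as in opt_params.get_params.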
+ if sys.platform == "darwin": + iree_flags.append("-iree-stream-fuse-binding=false") + + if args.import_mlir: + scaling_model = ScalingModel() + self.scaling_model = compile_through_fx( + scaling_model, + (example_latent, example_sigma), + model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}" + + args.precision, + extra_args=iree_flags, + ) + + step_model = SchedulerStepModel() + self.step_model = compile_through_fx( + step_model, + (example_output, example_sigma, example_latent, example_dt), + model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}" + + args.precision, + extra_args=iree_flags, + ) + else: + self.scaling_model = get_shark_model( + SCHEDULER_BUCKET, + "euler_scale_model_input_" + args.precision, + iree_flags, + ) + self.step_model = get_shark_model( + SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags + ) + + def scale_model_input(self, sample, timestep): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + return self.scaling_model( + "forward", + ( + sample, + sigma, + ), + send_to_host=False, + ) + + def step(self, noise_pred, timestep, latent): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + dt = self.sigmas[step_index + 1] - sigma + return self.step_model( + "forward", + ( + noise_pred, + sigma, + latent, + dt, + ), + send_to_host=False, + ) diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py new file mode 100644 index 00000000..f2108583 --- /dev/null +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -0,0 +1,19 @@ +from .profiler import start_profiling, end_profiling +from .resources import ( + prompt_examples, + models_db, + base_models, + opt_flags, + resource_path, +) +from .stable_args import args +from .utils import ( + get_shark_model, + compile_through_fx, + set_iree_runtime_flags, + map_device_to_name_path, + set_init_device_flags, + get_available_devices, + get_opt_flags, + preprocessCKPT, +) diff --git a/apps/stable_diffusion/src/utils/profiler.py b/apps/stable_diffusion/src/utils/profiler.py new file mode 100644 index 00000000..34ccff02 --- /dev/null +++ b/apps/stable_diffusion/src/utils/profiler.py @@ -0,0 +1,17 @@ +from .stable_args import args + +# Helper function to profile the vulkan device. 
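+# Returns the IREE runtime device with profiling started so the caller can
+# later pass it to end_profiling; returns None when profiling is disabled.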
+def start_profiling(file_path="foo.rdc", profiling_mode="queue"): + if args.vulkan_debug_utils and "vulkan" in args.device: + import iree + + print(f"Profiling and saving to {file_path}.") + vulkan_device = iree.runtime.get_device(args.device) + vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path) + return vulkan_device + return None + + +def end_profiling(device): + if device: + return device.end_profiling() diff --git a/apps/stable_diffusion/src/utils/resources.py b/apps/stable_diffusion/src/utils/resources.py new file mode 100644 index 00000000..a0ef07c5 --- /dev/null +++ b/apps/stable_diffusion/src/utils/resources.py @@ -0,0 +1,37 @@ +import os +import json +import sys + + +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) + + +def get_json_file(path): + json_var = [] + loc_json = resource_path(path) + if os.path.exists(loc_json): + with open(loc_json, encoding="utf-8") as fopen: + json_var = json.load(fopen) + + if not json_var: + print(f"Unable to fetch {path}") + + return json_var + + +# TODO: This shouldn't be called from here, every time the file imports +# it will run all the global vars. +prompt_examples = get_json_file("../../resources/prompts.json") +models_db = get_json_file("../../resources/model_db.json") + +# The base_model contains the input configuration for the different +# models and also helps in providing information for the variants. +base_models = get_json_file("../../resources/base_model.json") + +# Contains optimization flags for different models. +opt_flags = get_json_file("../../resources/opt_flags.json") diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py new file mode 100644 index 00000000..ccb53346 --- /dev/null +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -0,0 +1,323 @@ +import argparse +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +### Stable Diffusion Params +############################################################################## + +p.add_argument( + "-p", + "--prompts", + action="append", + default=[], + help="text of which images to be generated.", +) + +p.add_argument( + "--negative-prompts", + nargs="+", + default=[""], + help="text you don't want to see in the generated image.", +) + +p.add_argument( + "--steps", + type=int, + default=50, + help="the no. 
of steps to do the sampling.", +) + +p.add_argument( + "--seed", + type=int, + default=42, + help="the seed to use.", +) + +p.add_argument( + "--batch_size", + type=int, + default=1, + choices=range(1, 4), + help="the number of inferences to be made in a single `run`.", +) + +p.add_argument( + "--height", + type=int, + default=512, + help="the height of the output image.", +) + +p.add_argument( + "--width", + type=int, + default=512, + help="the width of the output image.", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="the value to be used for guidance scaling.", +) + +p.add_argument( + "--max_length", + type=int, + default=64, + help="max length of the tokenizer output, options are 64 and 77.", +) + +############################################################################## +### Model Config and Usage Params +############################################################################## + +p.add_argument( + "--device", type=str, default="vulkan", help="device to run the model." +) + +p.add_argument( + "--precision", type=str, default="fp16", help="precision to run the model." +) + +p.add_argument( + "--import_mlir", + default=False, + action=argparse.BooleanOptionalAction, + help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.", +) + +p.add_argument( + "--load_vmfb", + default=True, + action=argparse.BooleanOptionalAction, + help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.", +) + +p.add_argument( + "--save_vmfb", + default=False, + action=argparse.BooleanOptionalAction, + help="saves the compiled flatbuffer to the local directory", +) + +p.add_argument( + "--use_tuned", + default=True, + action=argparse.BooleanOptionalAction, + help="Download and use the tuned version of the model if available", +) + +p.add_argument( + "--use_base_vae", + default=False, + action=argparse.BooleanOptionalAction, + help="Do conversion from the VAE output to pixel space on cpu.", +) + +p.add_argument( + "--scheduler", + type=str, + default="SharkEulerDiscrete", + help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]", +) + +p.add_argument( + "--output_img_format", + type=str, + default="png", + help="specify the format in which output image is save. 
Supported options: jpg / png", +) + +p.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory path to save the output images and json", +) + +p.add_argument( + "--runs", + type=int, + default=1, + help="number of images to be generated with random seeds in single execution", +) + +p.add_argument( + "--ckpt_loc", + type=str, + default="", + help="Path to SD's .ckpt file.", +) + +p.add_argument( + "--hf_model_id", + type=str, + default="stabilityai/stable-diffusion-2-1-base", + help="The repo-id of hugging face.", +) + +p.add_argument( + "--enable_stack_trace", + default=False, + action=argparse.BooleanOptionalAction, + help="Enable showing the stack trace when retrying the base model configuration", +) + +############################################################################## +### IREE - Vulkan supported flags +############################################################################## + +p.add_argument( + "--iree-vulkan-target-triple", + type=str, + default="", + help="Specify target triple for vulkan", +) + +p.add_argument( + "--vulkan_debug_utils", + default=False, + action=argparse.BooleanOptionalAction, + help="Profiles vulkan device and collects the .rdc info", +) + +p.add_argument( + "--vulkan_large_heap_block_size", + default="4147483648", + help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G", +) + +p.add_argument( + "--vulkan_validation_layers", + default=False, + action=argparse.BooleanOptionalAction, + help="flag for disabling vulkan validation layers when benchmarking", +) + +############################################################################## +### Misc. Debug and Optimization flags +############################################################################## + +p.add_argument( + "--use_compiled_scheduler", + default=True, + action=argparse.BooleanOptionalAction, + help="use the default scheduler precompiled into the model if available", +) + +p.add_argument( + "--local_tank_cache", + default="", + help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.", +) + +p.add_argument( + "--dump_isa", + default=False, + action="store_true", + help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.", +) + +p.add_argument( + "--dispatch_benchmarks", + default=None, + help='dispatches to return benchamrk data on. use "All" for all, and None for none.', +) + +p.add_argument( + "--dispatch_benchmarks_dir", + default="temp_dispatch_benchmarks", + help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"', +) + +p.add_argument( + "--enable_rgp", + default=False, + action=argparse.BooleanOptionalAction, + help="flag for inserting debug frames between iterations for use with rgp.", +) + +p.add_argument( + "--hide_steps", + default=True, + action=argparse.BooleanOptionalAction, + help="flag for hiding the details of iteration/sec for each step.", +) + +p.add_argument( + "--warmup_count", + type=int, + default=0, + help="flag setting warmup count for clip and vae [>= 0].", +) + +p.add_argument( + "--clear_all", + default=False, + action=argparse.BooleanOptionalAction, + help="flag to clear all mlir and vmfb from common locations. 
Recompiling will take several minutes", +) + +############################################################################## +### Web UI flags +############################################################################## + +p.add_argument( + "--progress_bar", + default=True, + action=argparse.BooleanOptionalAction, + help="flag for removing the pregress bar animation during image generation", +) + +p.add_argument( + "--share", + default=False, + action=argparse.BooleanOptionalAction, + help="flag for generating a public URL", +) + +p.add_argument( + "--server_port", + type=int, + default=8080, + help="flag for setting server port", +) + +############################################################################## +### SD model auto-annotation flags +############################################################################## + +p.add_argument( + "--annotation_output", + type=path_expand, + default="./", + help="Directory to save the annotated mlir file", +) + +p.add_argument( + "--annotation_model", + type=str, + default="unet", + help="Options are unet and vae.", +) + +p.add_argument( + "--use_winograd", + default=False, + action=argparse.BooleanOptionalAction, + help="Apply Winograd on selected conv ops.", +) + +args = p.parse_args() diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py new file mode 100644 index 00000000..7c24188d --- /dev/null +++ b/apps/stable_diffusion/src/utils/utils.py @@ -0,0 +1,353 @@ +import os +import torch +from shark.shark_inference import SharkInference +from shark.shark_importer import import_with_fx +from shark.iree_utils.vulkan_utils import ( + set_iree_vulkan_runtime_flags, + get_vulkan_target_triple, +) +from shark.iree_utils.gpu_utils import get_cuda_sm_cc +from .stable_args import args +from .resources import opt_flags +import sys +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + load_pipeline_from_original_stable_diffusion_ckpt, +) + + +def _compile_module(shark_module, model_name, extra_args=[]): + if args.load_vmfb or args.save_vmfb: + device = ( + args.device + if "://" not in args.device + else "-".join(args.device.split("://")) + ) + extended_name = "{}_{}".format(model_name, device) + vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb") + if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb: + print(f"loading existing vmfb from: {vmfb_path}") + shark_module.load_module(vmfb_path, extra_args=extra_args) + else: + if args.save_vmfb: + print("Saving to {}".format(vmfb_path)) + else: + print( + "No vmfb found. Compiling and saving to {}".format( + vmfb_path + ) + ) + path = shark_module.save_module( + os.getcwd(), extended_name, extra_args + ) + shark_module.load_module(path, extra_args=extra_args) + else: + shark_module.compile(extra_args) + return shark_module + + +# Downloads the model from shark_tank and returns the shark_module. +def get_shark_model(tank_url, model_name, extra_args=[]): + from shark.shark_downloader import download_model + from shark.parser import shark_args + + # Set local shark_tank cache directory. 
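+ # Propagate --local_tank_cache to the downloader; an empty string leaves the
+ # downloader's default location (~/.local/shark_tank/) in place.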
+    shark_args.local_tank_cache = args.local_tank_cache
+    if "cuda" in args.device:
+        shark_args.enable_tf32 = True
+
+    mlir_model, func_name, inputs, golden_out = download_model(
+        model_name,
+        tank_url=tank_url,
+        frontend="torch",
+    )
+    shark_module = SharkInference(
+        mlir_model, device=args.device, mlir_dialect="linalg"
+    )
+    return _compile_module(shark_module, model_name, extra_args)
+
+
+# Converts the torch-module into a shark_module.
+def compile_through_fx(
+    model,
+    inputs,
+    model_name,
+    is_f16=False,
+    f16_input_mask=None,
+    extra_args=[],
+):
+
+    mlir_module, func_name = import_with_fx(
+        model, inputs, is_f16, f16_input_mask
+    )
+    shark_module = SharkInference(
+        mlir_module,
+        device=args.device,
+        mlir_dialect="linalg",
+    )
+
+    return _compile_module(shark_module, model_name, extra_args)
+
+
+def set_iree_runtime_flags():
+
+    vulkan_runtime_flags = [
+        f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
+        f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
+    ]
+    if args.enable_rgp:
+        vulkan_runtime_flags += [
+            f"--enable_rgp=true",
+            f"--vulkan_debug_utils=true",
+        ]
+    set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
+
+
+def get_all_devices(driver_name):
+    """
+    Inputs: driver_name
+    Returns a list of all the available devices for a given driver sorted by
+    the iree path names of the device as in --list_devices option in iree.
+    """
+    from iree.runtime import get_driver
+
+    driver = get_driver(driver_name)
+    device_list_src = driver.query_available_devices()
+    device_list_src.sort(key=lambda d: d["path"])
+    return device_list_src
+
+
+def get_device_mapping(driver, key_combination=3):
+    """This method ensures consistent device ordering when choosing
+    specific devices for execution
+    Args:
+        driver (str): execution driver (vulkan, cuda, rocm, etc)
+        key_combination (int, optional): choice for mapping value for device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Returns:
+        dict: map to possible device names user can input mapped to desired combination of name/path.
+    """
+    from shark.iree_utils._common import iree_device_map
+
+    driver = iree_device_map(driver)
+    device_list = get_all_devices(driver)
+    device_map = dict()
+
+    def get_output_value(dev_dict):
+        if key_combination == 1:
+            return f"{driver}://{dev_dict['path']}"
+        if key_combination == 2:
+            return dev_dict["name"]
+        if key_combination == 3:
+            return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
+
+    # mapping driver name to default device (driver://0)
+    device_map[f"{driver}"] = get_output_value(device_list[0])
+    for i, device in enumerate(device_list):
+        # mapping with index
+        device_map[f"{driver}://{i}"] = get_output_value(device)
+        # mapping with full path
+        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
+    return device_map
+
+
+def map_device_to_name_path(device, key_combination=3):
+    """Gives the appropriate device data (supported name/path) for user selected execution device
+    Args:
+        device (str): user selected device
+        key_combination (int, optional): choice for mapping value for device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Raises:
+        ValueError: if the given device is not found in the device mapping.
+    Returns:
+        str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
+    """
+    driver = device.split("://")[0]
+    device_map = get_device_mapping(driver, key_combination)
+    try:
+        device_mapping = device_map[device]
+    except KeyError:
+        raise ValueError(f"Device '{device}' is not a valid device.")
+    return device_mapping
+
+
+def set_init_device_flags():
+    if "vulkan" in args.device:
+        # set runtime flags for vulkan.
+        set_iree_runtime_flags()
+
+        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
+        device_name, args.device = map_device_to_name_path(args.device)
+        if not args.iree_vulkan_target_triple:
+            triple = get_vulkan_target_triple(device_name)
+            if triple is not None:
+                args.iree_vulkan_target_triple = triple
+        print(
+            f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
+        )
+    elif "cuda" in args.device:
+        args.device = "cuda"
+    elif "cpu" in args.device:
+        args.device = "cpu"
+
+    # set max_length based on availability.
+    if args.hf_model_id in [
+        "Linaqruf/anything-v3.0",
+        "wavymulder/Analog-Diffusion",
+        "dreamlike-art/dreamlike-diffusion-1.0",
+    ]:
+        args.max_length = 77
+    elif args.hf_model_id == "prompthero/openjourney":
+        args.max_length = 64
+
+    # Use tuned models in the case of a specific setting.
+    if (
+        args.hf_model_id
+        in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
+        or args.precision != "fp16"
+    ):
+        args.use_tuned = False
+
+    elif (
+        "vulkan" in args.device
+        and "rdna3" not in args.iree_vulkan_target_triple
+    ):
+        args.use_tuned = False
+
+    elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
+        args.use_tuned = False
+
+    elif args.use_base_vae and args.hf_model_id not in [
+        "stabilityai/stable-diffusion-2-1-base",
+        "CompVis/stable-diffusion-v1-4",
+    ]:
+        args.use_tuned = False
+
+    if args.use_tuned:
+        print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
+    else:
+        print("Tuned models are currently not supported for this setting.")
+
+    # set import_mlir to True for unuploaded models.
+    if args.hf_model_id not in [
+        "Linaqruf/anything-v3.0",
+        "dreamlike-art/dreamlike-diffusion-1.0",
+        "prompthero/openjourney",
+        "wavymulder/Analog-Diffusion",
+        "stabilityai/stable-diffusion-2-1",
+        "stabilityai/stable-diffusion-2-1-base",
+        "CompVis/stable-diffusion-v1-4",
+    ]:
+        args.import_mlir = True
+
+    if args.height != 512 or args.width != 512 or args.batch_size != 1:
+        args.import_mlir = True
+
+
+# Utility to get list of devices available.
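+# Queries vulkan and cuda devices through the IREE runtime and always appends "cpu" as a fallback.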
+def get_available_devices():
+    def get_devices_by_name(driver_name):
+        from shark.iree_utils._common import iree_device_map
+
+        device_list = []
+        try:
+            driver_name = iree_device_map(driver_name)
+            device_list_dict = get_all_devices(driver_name)
+            print(f"{driver_name} devices are available.")
+        except:
+            print(f"{driver_name} devices are not available.")
+        else:
+            for i, device in enumerate(device_list_dict):
+                device_list.append(f"{device['name']} => {driver_name}://{i}")
+        return device_list
+
+    set_iree_runtime_flags()
+
+    available_devices = []
+    vulkan_devices = get_devices_by_name("vulkan")
+    available_devices.extend(vulkan_devices)
+    cuda_devices = get_devices_by_name("cuda")
+    available_devices.extend(cuda_devices)
+    available_devices.append("cpu")
+    return available_devices
+
+
+def disk_space_check(path, lim=20):
+    from shutil import disk_usage
+
+    du = disk_usage(path)
+    free = du.free / (1024 * 1024 * 1024)
+    if free <= lim:
+        print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
+
+
+def get_opt_flags(model, precision="fp16"):
+    iree_flags = []
+    is_tuned = "tuned" if args.use_tuned else "untuned"
+    if len(args.iree_vulkan_target_triple) > 0:
+        iree_flags.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
+    # Disable bindings fusion to work with moltenVK.
+    if sys.platform == "darwin":
+        iree_flags.append("-iree-stream-fuse-binding=false")
+
+    if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
+        device = (
+            args.device
+            if "://" not in args.device
+            else args.device.split("://")[0]
+        )
+        if (
+            device
+            not in opt_flags[model][is_tuned][precision][
+                "specified_compilation_flags"
+            ]
+        ):
+            device = "default_device"
+        iree_flags += opt_flags[model][is_tuned][precision][
+            "specified_compilation_flags"
+        ][device]
+
+    return iree_flags
+
+
+def preprocessCKPT():
+    from pathlib import Path
+
+    path = Path(args.ckpt_loc)
+    diffusers_path = path.parent.absolute()
+    diffusers_directory_name = path.stem
+    complete_path_to_diffusers = diffusers_path / diffusers_directory_name
+    complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
+    print(
+        "Created directory : ",
+        diffusers_directory_name,
+        " at -> ",
+        diffusers_path,
+    )
+    path_to_diffusers = complete_path_to_diffusers.as_posix()
+    from_safetensors = (
+        True if args.ckpt_loc.lower().endswith(".safetensors") else False
+    )
+    # EMA weights usually yield higher quality images for inference but non-EMA weights have
+    # been yielding better results in our case.
+    # TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
+    # weight extraction or not.
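+    # Until that option exists, non-EMA weights are always extracted.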
+    extract_ema = False
+    print("Loading pipeline from original stable diffusion checkpoint")
+    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+        checkpoint_path=args.ckpt_loc,
+        extract_ema=extract_ema,
+        from_safetensors=from_safetensors,
+    )
+    pipe.save_pretrained(path_to_diffusers)
+    print("Loading complete")
+    args.ckpt_loc = path_to_diffusers
+    print("Custom model path is : ", args.ckpt_loc)
diff --git a/apps/stable_diffusion/web/css/sd_dark_theme.css b/apps/stable_diffusion/web/css/sd_dark_theme.css
new file mode 100644
index 00000000..7a48098d
--- /dev/null
+++ b/apps/stable_diffusion/web/css/sd_dark_theme.css
@@ -0,0 +1,67 @@
+.gradio-container {
+    background-color: black
+}
+
+.container {
+    background-color: black !important;
+    padding-top: 20px !important;
+}
+
+#ui_title {
+    padding: 10px !important;
+}
+
+#top_logo {
+    background-color: transparent;
+    border-radius: 0 !important;
+    border: 0;
+}
+
+#demo_title {
+    background-color: black;
+    border-radius: 0 !important;
+    border: 0;
+    padding-top: 50px;
+    padding-bottom: 0px;
+    width: 460px !important;
+}
+
+#demo_title_outer {
+    border-radius: 0;
+}
+
+#prompt_box_outer div:first-child {
+    border-radius: 0 !important
+}
+
+#prompt_box textarea {
+    background-color: #1d1d1d !important
+}
+
+#prompt_examples {
+    margin: 0 !important
+}
+
+#prompt_examples svg {
+    display: none !important;
+}
+
+.gr-sample-textbox {
+    border-radius: 1rem !important;
+    border-color: rgb(31, 41, 55) !important;
+    border-width: 2px !important;
+}
+
+#ui_body {
+    background-color: #111111 !important;
+    padding: 10px !important;
+    border-radius: 0.5em !important;
+}
+
+#img_result+div {
+    display: none !important;
+}
+
+footer {
+    display: none !important;
+}
diff --git a/apps/stable_diffusion/web/gradio/img2img_ui.py b/apps/stable_diffusion/web/gradio/img2img_ui.py
new file mode 100644
index 00000000..e69de29b
diff --git a/apps/stable_diffusion/web/gradio/txt2img_ui.py b/apps/stable_diffusion/web/gradio/txt2img_ui.py
new file mode 100644
index 00000000..e69de29b
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
new file mode 100644
index 00000000..6c9009f5
--- /dev/null
+++ b/apps/stable_diffusion/web/index.py
@@ -0,0 +1,249 @@
+import os
+import sys
+from pathlib import Path
+
+if "AMD_ENABLE_LLPC" not in os.environ:
+    os.environ["AMD_ENABLE_LLPC"] = "1"
+
+if sys.platform == "darwin":
+    os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
+
+
+def resource_path(relative_path):
+    """Get absolute path to resource, works for dev and for PyInstaller"""
+    base_path = getattr(
+        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
+    )
+    return os.path.join(base_path, relative_path)
+
+
+import gradio as gr
+from PIL import Image
+from apps.stable_diffusion.src import (
+    prompt_examples,
+    args,
+    get_available_devices,
+)
+from apps.stable_diffusion.scripts import txt2img_inf
+
+nodlogo_loc = resource_path("logos/nod-logo.png")
+sdlogo_loc = resource_path("logos/sd-demo-logo.png")
+
+
+demo_css = resource_path("css/sd_dark_theme.css")
+
+
+with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
+
+    with gr.Row(elem_id="ui_title"):
+        nod_logo = Image.open(nodlogo_loc)
+        logo2 = Image.open(sdlogo_loc)
+        with gr.Row():
+            with gr.Column(scale=1, elem_id="demo_title_outer"):
+                gr.Image(
+                    value=nod_logo,
+                    show_label=False,
+                    interactive=False,
+                    elem_id="top_logo",
+                ).style(width=150, height=100)
+            with gr.Column(scale=5, elem_id="demo_title_outer"):
+                gr.Image(
+                    value=logo2,
+                    show_label=False,
+                    interactive=False,
+                    elem_id="demo_title",
+                ).style(width=150, height=100)
+
+    with gr.Row(elem_id="ui_body"):
+
+        with gr.Row():
+            with gr.Column(scale=1, min_width=600):
+                with gr.Row():
+                    with gr.Group():
+                        model_id = gr.Dropdown(
+                            label="Model ID",
+                            value="stabilityai/stable-diffusion-2-1-base",
+                            choices=[
+                                "Linaqruf/anything-v3.0",
+                                "prompthero/openjourney",
+                                "wavymulder/Analog-Diffusion",
+                                "stabilityai/stable-diffusion-2-1",
+                                "stabilityai/stable-diffusion-2-1-base",
+                                "CompVis/stable-diffusion-v1-4",
+                            ],
+                        )
+                        custom_model_id = gr.Textbox(
+                            placeholder="check here: https://huggingface.co/models e.g. runwayml/stable-diffusion-v1-5",
+                            value="",
+                            label="HuggingFace Model ID",
+                        )
+                    with gr.Group():
+                        ckpt_loc = gr.File(
+                            label="Upload checkpoint",
+                            file_types=[".ckpt", ".safetensors"],
+                        )
+
+                with gr.Group(elem_id="prompt_box_outer"):
+                    prompt = gr.Textbox(
+                        label="Prompt",
+                        value="cyberpunk forest by Salvador Dali",
+                        lines=1,
+                        elem_id="prompt_box",
+                    )
+                    negative_prompt = gr.Textbox(
+                        label="Negative Prompt",
+                        value="trees, green",
+                        lines=1,
+                        elem_id="prompt_box",
+                    )
+                with gr.Accordion(label="Advanced Options", open=False):
+                    with gr.Row():
+                        scheduler = gr.Dropdown(
+                            label="Scheduler",
+                            value="SharkEulerDiscrete",
+                            choices=[
+                                "DDIM",
+                                "PNDM",
+                                "LMSDiscrete",
+                                "DPMSolverMultistep",
+                                "EulerDiscrete",
+                                "EulerAncestralDiscrete",
+                                "SharkEulerDiscrete",
+                            ],
+                        )
+                        batch_size = gr.Slider(
+                            1, 4, value=1, step=1, label="Number of Images"
+                        )
+                    with gr.Row():
+                        height = gr.Slider(
+                            384, 768, value=512, step=8, label="Height"
+                        )
+                        width = gr.Slider(
+                            384, 768, value=512, step=8, label="Width"
+                        )
+                        precision = gr.Radio(
+                            label="Precision",
+                            value="fp16",
+                            choices=[
+                                "fp16",
+                                "fp32",
+                            ],
+                            visible=False,
+                        )
+                        max_length = gr.Radio(
+                            label="Max Length",
+                            value=64,
+                            choices=[
+                                64,
+                                77,
+                            ],
+                            visible=False,
+                        )
+                    with gr.Row():
+                        steps = gr.Slider(
+                            1, 100, value=50, step=1, label="Steps"
+                        )
+                        guidance_scale = gr.Slider(
+                            0,
+                            50,
+                            value=7.5,
+                            step=0.1,
+                            label="CFG Scale",
+                        )
+                    with gr.Row():
+                        seed = gr.Number(value=-1, precision=0, label="Seed")
+                        available_devices = get_available_devices()
+                        device = gr.Dropdown(
+                            label="Device",
+                            value=available_devices[0],
+                            choices=available_devices,
+                        )
+                with gr.Row():
+                    random_seed = gr.Button("Randomize Seed")
+                    random_seed.click(
+                        None,
+                        inputs=[],
+                        outputs=[seed],
+                        _js="() => Math.floor(Math.random() * 4294967295)",
+                    )
+                    stable_diffusion = gr.Button("Generate Image")
+                with gr.Accordion(label="Prompt Examples!", open=False):
+                    ex = gr.Examples(
+                        examples=prompt_examples,
+                        inputs=prompt,
+                        cache_examples=False,
+                        elem_id="prompt_examples",
+                    )
+
+            with gr.Column(scale=1, min_width=600):
+                with gr.Group():
+                    gallery = gr.Gallery(
+                        label="Generated images",
+                        show_label=False,
+                        elem_id="gallery",
+                    ).style(grid=[2], height="auto")
+                    std_output = gr.Textbox(
+                        value="Nothing to show.",
+                        lines=4,
+                        show_label=False,
+                    )
+                output_dir = args.output_dir if args.output_dir else Path.cwd()
+                output_dir = Path(output_dir, "generated_imgs")
+                output_loc = gr.Textbox(
+                    label="Saving Images at",
+                    value=output_dir,
+                    interactive=False,
+                )
+
+        prompt.submit(
+            txt2img_inf,
+            inputs=[
+                prompt,
+                negative_prompt,
+                height,
+                width,
+                steps,
+                guidance_scale,
+                seed,
+                batch_size,
+                scheduler,
+                model_id,
+                custom_model_id,
+                ckpt_loc,
+                precision,
+                device,
+                max_length,
+            ],
+            outputs=[gallery, std_output],
+            show_progress=args.progress_bar,
+        )
+        stable_diffusion.click(
+            txt2img_inf,
+            inputs=[
+                prompt,
+                negative_prompt,
+                height,
+                width,
+                steps,
+                guidance_scale,
+                seed,
+                batch_size,
+                scheduler,
+                model_id,
+                custom_model_id,
+                ckpt_loc,
+                precision,
+                device,
+                max_length,
+            ],
+            outputs=[gallery, std_output],
+            show_progress=args.progress_bar,
+        )
+
+shark_web.queue()
+shark_web.launch(
+    share=args.share,
+    inbrowser=True,
+    server_name="0.0.0.0",
+    server_port=args.server_port,
+)
diff --git a/apps/stable_diffusion/web/logos/Nod_logo.png b/apps/stable_diffusion/web/logos/Nod_logo.png
new file mode 100644
index 00000000..85221e69
Binary files /dev/null and b/apps/stable_diffusion/web/logos/Nod_logo.png differ
diff --git a/apps/stable_diffusion/web/logos/nod-logo.png b/apps/stable_diffusion/web/logos/nod-logo.png
new file mode 100644
index 00000000..4727e15a
Binary files /dev/null and b/apps/stable_diffusion/web/logos/nod-logo.png differ
diff --git a/apps/stable_diffusion/web/logos/sd-demo-logo.png b/apps/stable_diffusion/web/logos/sd-demo-logo.png
new file mode 100644
index 00000000..afaf8eb8
Binary files /dev/null and b/apps/stable_diffusion/web/logos/sd-demo-logo.png differ
diff --git a/shark/examples/shark_inference/stable_diffusion/resources/model_config.json b/shark/examples/shark_inference/stable_diffusion/resources/model_config.json
index 0e241795..207afe4e 100644
--- a/shark/examples/shark_inference/stable_diffusion/resources/model_config.json
+++ b/shark/examples/shark_inference/stable_diffusion/resources/model_config.json
@@ -1,21 +1,21 @@
 [
-  {
-    "stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
-    "stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
-    "stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
-    "anythingv3/v1_4":"Linaqruf/anything-v3.0",
-    "analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
-    "openjourney/v1_4":"prompthero/openjourney",
-    "dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
-  },
-  {
-    "stablediffusion/fp16":"fp16",
-    "stablediffusion/fp32":"main",
-    "anythingv3/fp16":"diffusers",
-    "anythingv3/fp32":"diffusers",
-    "analogdiffusion/fp16":"main",
-    "analogdiffusion/fp32":"main",
-    "openjourney/fp16":"main",
-    "openjourney/fp32":"main"
-  }
- ]
\ No newline at end of file
+  {
+    "stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
+    "stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
+    "stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
+    "anythingv3/v1_4":"Linaqruf/anything-v3.0",
+    "analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
+    "openjourney/v1_4":"prompthero/openjourney",
+    "dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
+  },
+  {
+    "stablediffusion/fp16":"fp16",
+    "stablediffusion/fp32":"main",
+    "anythingv3/fp16":"diffusers",
+    "anythingv3/fp32":"diffusers",
+    "analogdiffusion/fp16":"main",
+    "analogdiffusion/fp32":"main",
+    "openjourney/fp16":"main",
+    "openjourney/fp32":"main"
+  }
+]
diff --git a/shark/examples/shark_inference/stable_diffusion/stable_args.py b/shark/examples/shark_inference/stable_diffusion/stable_args.py
index d9b6f7d8..0b78f3fb 100644
--- a/shark/examples/shark_inference/stable_diffusion/stable_args.py
+++ b/shark/examples/shark_inference/stable_diffusion/stable_args.py
@@ -294,6 +294,20 @@ p.add_argument(
     help="flag for removing the pregress bar animation during image generation",
 )
 
+p.add_argument(
+    "--share",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="flag for generating a public URL",
+)
+
+p.add_argument(
+    "--server_port",
+    type=int,
+    default=8080,
+    help="flag for setting server port",
+)
+
 ##############################################################################
 ### SD model auto-annotation flags
 ##############################################################################
diff --git a/shark/examples/shark_inference/stable_diffusion/utils.py b/shark/examples/shark_inference/stable_diffusion/utils.py
index 45a3ce02..de2652d5 100644
--- a/shark/examples/shark_inference/stable_diffusion/utils.py
+++ b/shark/examples/shark_inference/stable_diffusion/utils.py
@@ -321,7 +321,7 @@ def get_available_devices():
             print(f"{driver_name} devices are not available.")
         else:
             for i, device in enumerate(device_list_dict):
-                device_list.append(f"{driver_name}://{i} => {device['name']}")
+                device_list.append(f"{device['name']} => {driver_name}://{i}")
         return device_list
 
     set_iree_runtime_flags()