diff --git a/setup_venv.ps1 b/setup_venv.ps1 index a4b9bdc8..20ce4527 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -35,6 +35,6 @@ pip install --pre torch-mlir torch torchvision --extra-index-url https://downloa pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime Write-Host "Building SHARK..." pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html -pip install diffusers transformers scipy +pip install diffusers transformers scipy gradio Write-Host "Build and installation completed successfully" Write-Host "Source your venv with ./shark.venv/Scripts/activate" diff --git a/web/index.py b/web/index.py index 25e18df3..2f67e712 100644 --- a/web/index.py +++ b/web/index.py @@ -16,12 +16,20 @@ with gr.Blocks() as shark_web: with gr.Row(): with gr.Group(): with gr.Column(scale=1): - img = Image.open("./Nod_logo.png") - gr.Image(value=img, show_label=False, interactive=False).style( - height=80, width=150 - ) + nod_logo = Image.open("./logos/Nod_logo.png") + gr.Image( + value=nod_logo, show_label=False, interactive=False + ).style(height=80, width=150) with gr.Column(scale=1): - gr.Label(value="Shark Models Demo.") + logo2 = Image.open("./logos/other_logo.png") + gr.Image( + value=logo2, + show_label=False, + interactive=False, + visible=False, + ).style(height=80, width=150) + with gr.Column(scale=1): + gr.Label(value="Ultra fast Stable Diffusion") with gr.Tabs(): # with gr.TabItem("ResNet50"): @@ -136,9 +144,7 @@ with gr.Blocks() as shark_web: ) = ( device ) = ( - load_vmfb - ) = ( - save_vmfb + cache ) = ( iree_vulkan_target_triple ) = ( @@ -160,6 +166,7 @@ with gr.Blocks() as shark_web: prompt = gr.Textbox( label="Prompt", value="a photograph of an astronaut riding a horse", + lines=5, ) ex = gr.Examples( examples=examples, @@ -219,11 +226,10 @@ with gr.Blocks() as shark_web: value="42", max_lines=1, label="Seed" ) with gr.Row(): - load_vmfb = gr.Checkbox(label="Load vmfb", value=True) - save_vmfb = gr.Checkbox(label="Save vmfb", value=False) + cache = gr.Checkbox(label="Cache", value=True) debug = gr.Checkbox(label="DEBUG", value=False) live_preview = gr.Checkbox( - label="live preview", value=False + label="Live Preview", value=False ) iree_vulkan_target_triple = gr.Textbox( value="", @@ -231,7 +237,7 @@ with gr.Blocks() as shark_web: label="IREE VULKAN TARGET TRIPLE", visible=False, ) - stable_diffusion = gr.Button("Generate image from prompt") + stable_diffusion = gr.Button("Generate Image") with gr.Column(scale=1, min_width=600): generated_img = gr.Image(type="pil", shape=(100, 100)) std_output = gr.Textbox( @@ -261,8 +267,7 @@ with gr.Blocks() as shark_web: seed, precision, device, - load_vmfb, - save_vmfb, + cache, iree_vulkan_target_triple, live_preview, ], @@ -270,4 +275,4 @@ with gr.Blocks() as shark_web: ) shark_web.queue() -shark_web.launch(share=True, server_port=8080, enable_queue=True) +shark_web.launch(server_port=8080, enable_queue=True) diff --git a/web/Nod_logo.png b/web/logos/Nod_logo.png similarity index 100% rename from web/Nod_logo.png rename to web/logos/Nod_logo.png diff --git a/web/logos/other_logo.png b/web/logos/other_logo.png new file mode 100644 index 00000000..85221e69 Binary files /dev/null and b/web/logos/other_logo.png differ diff --git a/web/models/stable_diffusion/main.py b/web/models/stable_diffusion/main.py index 0c7d0ee9..fe5a92a0 100644 --- a/web/models/stable_diffusion/main.py +++ b/web/models/stable_diffusion/main.py @@ -22,7 +22,6 @@ UNET_FP32 = "unet_fp32" IREE_EXTRA_ARGS = [] args = None -DEBUG = False class Arguments: @@ -39,8 +38,7 @@ class Arguments: seed: int, precision: str, device: str, - load_vmfb: bool, - save_vmfb: bool, + cache: bool, iree_vulkan_target_triple: str, live_preview: bool, import_mlir: bool = False, @@ -57,8 +55,7 @@ class Arguments: self.seed = seed self.precision = precision self.device = device - self.load_vmfb = load_vmfb - self.save_vmfb = save_vmfb + self.cache = cache self.iree_vulkan_target_triple = iree_vulkan_target_triple self.live_preview = live_preview self.import_mlir = import_mlir @@ -101,6 +98,37 @@ def get_models(): return None, None +schedulers = dict() +# set scheduler value +schedulers["PNDM"] = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, +) +schedulers["LMS"] = LMSDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, +) +schedulers["DDIM"] = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, +) + +cache_obj = dict() +cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14" +) +cache_obj["text_encoder"] = CLIPTextModel.from_pretrained( + "openai/clip-vit-large-patch14" +) + + def stable_diff_inf( prompt: str, scheduler: str, @@ -113,21 +141,17 @@ def stable_diff_inf( seed: str, precision: str, device: str, - load_vmfb: bool, - save_vmfb: bool, + cache: bool, iree_vulkan_target_triple: str, live_preview: bool, ): global IREE_EXTRA_ARGS global args - global DEBUG + global schedulers + global cache_obj - output_loc = f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{precision}_{device}.jpg" - DEBUG = False - log_write = open(r"logs/stable_diffusion_log.txt", "w") - if log_write: - DEBUG = True + output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg" # set seed value if seed == "": @@ -138,34 +162,7 @@ def stable_diff_inf( except ValueError: seed = hash(seed) - # set scheduler value - if scheduler == "PNDM": - scheduler = PNDMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - ) - elif scheduler == "LMS": - scheduler = LMSDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - ) - elif scheduler == "DDIM": - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) - else: - raise Exception( - f"Does not support scheduler with name {args.scheduler}." - ) - + scheduler = schedulers[scheduler] args = Arguments( prompt, scheduler, @@ -178,8 +175,7 @@ def stable_diff_inf( seed, precision, device, - load_vmfb, - save_vmfb, + cache, iree_vulkan_target_triple, live_preview, ) @@ -194,11 +190,8 @@ def stable_diff_inf( ) # Seed generator to create the inital latent noise vae, unet = get_models() - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - text_encoder = CLIPTextModel.from_pretrained( - "openai/clip-vit-large-patch14" - ) - + tokenizer = cache_obj["tokenizer"] + text_encoder = cache_obj["text_encoder"] text_input = tokenizer( [args.prompt] * batch_size, padding="max_length", @@ -233,10 +226,10 @@ def stable_diff_inf( avg_ms = 0 out_img = None + text_output = "" for i, t in tqdm(enumerate(scheduler.timesteps)): - if DEBUG: - log_write.write(f"\ni = {i} t = {t} ") + text_output = text_output + f"\ni = {i} t = {t} " step_start = time.time() timestep = torch.tensor([t]).to(dtype).detach().numpy() latents_numpy = latents.detach().numpy() @@ -249,8 +242,7 @@ def stable_diff_inf( step_time = time.time() - step_start avg_ms += step_time step_ms = int((step_time) * 1000) - if DEBUG: - log_write.write(f"time={step_ms}ms") + text_output = text_output + f"time={step_ms}ms" latents = scheduler.step(noise_pred, i, latents)["prev_sample"] if live_preview: @@ -263,7 +255,7 @@ def stable_diff_inf( images = (image * 255).round().astype("uint8") pil_images = [Image.fromarray(image) for image in images] out_img = pil_images[0] - yield out_img, "" + yield out_img, text_output # scale and decode the image latents with vae if not live_preview: @@ -277,14 +269,8 @@ def stable_diff_inf( out_img = pil_images[0] avg_ms = 1000 * avg_ms / args.steps - if DEBUG: - log_write.write(f"\nAverage step time: {avg_ms}ms/it") + text_output = text_output + f"\nAverage step time: {avg_ms}ms/it" # save the output image with the prompt name. out_img.save(os.path.join(output_loc)) - log_write.close() - - std_output = "" - with open(r"logs/stable_diffusion_log.txt", "r") as log_read: - std_output = log_read.read() - yield out_img, std_output + yield out_img, text_output diff --git a/web/models/stable_diffusion/utils.py b/web/models/stable_diffusion/utils.py index b6c54369..71283fa7 100644 --- a/web/models/stable_diffusion/utils.py +++ b/web/models/stable_diffusion/utils.py @@ -7,27 +7,16 @@ import os def _compile_module(args, shark_module, model_name, extra_args=[]): - if args.load_vmfb or args.save_vmfb: - extended_name = "{}_{}".format(model_name, args.device) + extended_name = "{}_{}".format(model_name, args.device) + if args.cache: vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb") - if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb: + if os.path.isfile(vmfb_path): print("Loading flatbuffer from {}".format(vmfb_path)) shark_module.load_module(vmfb_path) - else: - if args.save_vmfb: - print("Saving to {}".format(vmfb_path)) - else: - print( - "No vmfb found. Compiling and saving to {}".format( - vmfb_path - ) - ) - path = shark_module.save_module( - os.getcwd(), extended_name, extra_args - ) - shark_module.load_module(path) - else: - shark_module.compile(extra_args) + return shark_module + print("No vmfb found. Compiling and saving to {}".format(vmfb_path)) + path = shark_module.save_module(os.getcwd(), extended_name, extra_args) + shark_module.load_module(path) return shark_module