From 88f8718635f165c82b8aea1dbce0196fc0a2259e Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Wed, 2 Nov 2022 18:37:39 +0530 Subject: [PATCH] [WEB] Load prompts from json The prompt examples will now be loaded from a json file `prompts.json`. Signed-Off-by: Gaurav Shukla --- web/README.md | 9 +++++-- web/index.py | 38 ++++++++++++++++----------- web/models/stable_diffusion/main.py | 40 ++++++++++++++++++----------- web/prompts.json | 9 +++++++ 4 files changed, 64 insertions(+), 32 deletions(-) create mode 100644 web/prompts.json diff --git a/web/README.md b/web/README.md index c95cc556..4f4091e0 100644 --- a/web/README.md +++ b/web/README.md @@ -1,5 +1,6 @@ In order to launch SHARK-web, from the root SHARK directory, run: +## Linux ```shell IMPORTER=1 ./setup_venv.sh source shark.venv/bin/activate @@ -7,5 +8,9 @@ cd web python index.py ``` -This will launch a gradio server with a public URL like: -Running on public URL: https://xxxxx.gradio.app +## Windows +```shell +./setup_venv.ps1 +cd web +python index.py --local_tank_cache= +``` diff --git a/web/index.py b/web/index.py index a2c09135..bf58d80d 100644 --- a/web/index.py +++ b/web/index.py @@ -5,6 +5,8 @@ from models.stable_diffusion.main import stable_diff_inf # from models.diffusion.v_diffusion import vdiff_inf import gradio as gr from PIL import Image +import json +import os def debug_event(debug): @@ -149,16 +151,15 @@ with gr.Blocks() as shark_web: iree_vulkan_target_triple ) = ( live_preview - ) = debug = stable_diffusion = generated_img = std_output = None - examples = [ - ["A high tech solarpunk utopia in the Amazon rainforest"], - ["A pikachu fine dining with a view to the Eiffel Tower"], - ["A mecha robot in a favela in expressionist style"], - ["an insect robot preparing a delicious meal"], - [ - "A small cabin on top of a snowy mountain in the style of Disney, artstation" - ], - ] + ) = ( + debug + ) = save_img = stable_diffusion = generated_img = std_output = None + # load prompts. + prompt_examples = [] + prompt_loc = "./prompts.json" + if os.path.exists(prompt_loc): + fopen = open("./prompts.json") + prompt_examples = json.load(fopen) with gr.Row(): with gr.Column(scale=1, min_width=600): @@ -169,7 +170,7 @@ with gr.Blocks() as shark_web: lines=5, ) ex = gr.Examples( - examples=examples, + examples=prompt_examples, inputs=prompt, cache_examples=False, ) @@ -195,7 +196,12 @@ with gr.Blocks() as shark_web: 1, 100, value=50, step=1, label="Steps" ) guidance = gr.Slider( - 0, 50, value=7.5, step=0.1, label="Guidance Scale" + 0, + 50, + value=7.5, + step=0.1, + label="Guidance Scale", + interactive=False, ) with gr.Row(): height = gr.Slider( @@ -232,7 +238,7 @@ with gr.Blocks() as shark_web: with gr.Row(): precision = gr.Radio( label="Precision", - value="fp32", + value="fp16", choices=["fp16", "fp32"], ) seed = gr.Textbox( @@ -240,10 +246,11 @@ with gr.Blocks() as shark_web: ) with gr.Row(): cache = gr.Checkbox(label="Cache", value=True) - debug = gr.Checkbox(label="DEBUG", value=False) live_preview = gr.Checkbox( label="Live Preview", value=False ) + debug = gr.Checkbox(label="DEBUG", value=False) + save_img = gr.Checkbox(label="Save Image", value=False) iree_vulkan_target_triple = gr.Textbox( value="", max_lines=1, @@ -256,7 +263,7 @@ with gr.Blocks() as shark_web: std_output = gr.Textbox( label="Std Output", value="Nothing.", - lines=10, + lines=5, visible=False, ) @@ -283,6 +290,7 @@ with gr.Blocks() as shark_web: cache, iree_vulkan_target_triple, live_preview, + save_img, ], outputs=[generated_img, std_output], ) diff --git a/web/models/stable_diffusion/main.py b/web/models/stable_diffusion/main.py index a4693d5b..dbfc8b2f 100644 --- a/web/models/stable_diffusion/main.py +++ b/web/models/stable_diffusion/main.py @@ -23,7 +23,6 @@ UNET_FP32 = "unet_fp32" TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn" UNET_FP16_TUNED = "unet_fp16_tunedv2" -IREE_EXTRA_ARGS = [] args = None @@ -44,6 +43,7 @@ class Arguments: cache: bool, iree_vulkan_target_triple: str, live_preview: bool, + save_img: bool, import_mlir: bool = False, max_length: int = 77, use_tuned: bool = True, @@ -62,6 +62,7 @@ class Arguments: self.cache = cache self.iree_vulkan_target_triple = iree_vulkan_target_triple self.live_preview = live_preview + self.save_img = save_img self.import_mlir = import_mlir self.max_length = max_length self.use_tuned = use_tuned @@ -69,9 +70,9 @@ class Arguments: def get_models(): - global IREE_EXTRA_ARGS global args + IREE_EXTRA_ARGS = [] if args.precision == "fp16": IREE_EXTRA_ARGS += [ "--iree-flow-enable-padding-linalg-ops", @@ -182,9 +183,10 @@ args = Arguments( seed=42, precision="fp16", device="vulkan", - cache=True, + cache=False, iree_vulkan_target_triple="", live_preview=False, + save_img=False, import_mlir=False, max_length=77, use_tuned=True, @@ -193,6 +195,9 @@ cache_obj["vae_fp16_vulkan"], cache_obj["unet_fp16_vulkan"] = get_models() args.precision = "fp32" cache_obj["vae_fp32_vulkan"], cache_obj["unet_fp32_vulkan"] = get_models() +output_dir = "./stored_results/stable_diffusion" +os.makedirs(output_dir, exist_ok=True) + def stable_diff_inf( prompt: str, @@ -209,14 +214,15 @@ def stable_diff_inf( cache: bool, iree_vulkan_target_triple: str, live_preview: bool, + save_img: bool, ): - global IREE_EXTRA_ARGS global args global schedulers global cache_obj + global output_dir - output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg" + start = time.time() # set seed value if seed == "": @@ -224,7 +230,9 @@ def stable_diff_inf( else: try: seed = int(seed) - except ValueError: + if seed < 0 or seed > 10000: + seed = hash(seed) + except (ValueError, OverflowError) as error: seed = hash(seed) scheduler = schedulers[scheduler] @@ -243,12 +251,9 @@ def stable_diff_inf( cache, iree_vulkan_target_triple, live_preview, + save_img, ) dtype = torch.float32 if args.precision == "fp32" else torch.half - if len(args.iree_vulkan_target_triple) > 0: - IREE_EXTRA_ARGS.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) num_inference_steps = int(args.steps) # Number of denoising steps generator = torch.manual_seed( args.seed @@ -310,7 +315,7 @@ def stable_diff_inf( text_output = "" for i, t in tqdm(enumerate(scheduler.timesteps)): - text_output = text_output + f"\ni = {i} t = {t} " + text_output += f"\n Iteration = {i} | Timestep = {t} | " step_start = time.time() timestep = torch.tensor([t]).to(dtype).detach().numpy() latents_numpy = latents.detach().numpy() @@ -323,7 +328,7 @@ def stable_diff_inf( step_time = time.time() - step_start avg_ms += step_time step_ms = int((step_time) * 1000) - text_output = text_output + f"time={step_ms}ms" + text_output += f"Time = {step_ms}ms." latents = scheduler.step(noise_pred, i, latents)["prev_sample"] if live_preview and i % 5 == 0: @@ -348,8 +353,13 @@ def stable_diff_inf( out_img = pil_images[0] avg_ms = 1000 * avg_ms / args.steps - text_output = text_output + f"\nAverage step time: {avg_ms}ms/it" + text_output += f"\n\nAverage step time: {avg_ms}ms/it" - # save the output image with the prompt name. - out_img.save(os.path.join(output_loc)) + total_time = time.time() - start + text_output += f"\n\nTotal image generation time: {total_time}sec" + + if args.save_img: + # save outputs. + output_loc = f"{output_dir}/{time.time()}_{int(args.steps)}_{args.precision}_{args.device}.jpg" + out_img.save(os.path.join(output_loc)) yield out_img, text_output diff --git a/web/prompts.json b/web/prompts.json new file mode 100644 index 00000000..d520409f --- /dev/null +++ b/web/prompts.json @@ -0,0 +1,9 @@ +[["A high tech solarpunk utopia in the Amazon rainforest"], +["A pikachu fine dining with a view to the Eiffel Tower"], +["A mecha robot in a favela in expressionist style"], +["an insect robot preparing a delicious meal"], +["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"], +["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"], +["A beautiful castle beside a waterfall in the woods, by Josef Thoma, matte painting, trending on artstation HQ"], +["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"], +["A small cabin on top of a snowy mountain in the style of Disney, artstation"]]