[WEB] Load prompts from json

The prompt examples are now loaded from a JSON file, `prompts.json`, instead of being hard-coded in the web UI script.

Signed-Off-by: Gaurav Shukla
This commit is contained in:
Gaurav Shukla
2022-11-02 18:37:39 +05:30
parent a081733a42
commit 88f8718635
4 changed files with 64 additions and 32 deletions

View File

@@ -1,5 +1,6 @@
In order to launch SHARK-web, from the root SHARK directory, run:
## Linux
```shell
IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
@@ -7,5 +8,9 @@ cd web
python index.py
```
This will launch a gradio server with a public URL like:
Running on public URL: https://xxxxx.gradio.app
## Windows
```shell
./setup_venv.ps1
cd web
python index.py --local_tank_cache=<current_working_dir>
```

View File

@@ -5,6 +5,8 @@ from models.stable_diffusion.main import stable_diff_inf
# from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
from PIL import Image
import json
import os
def debug_event(debug):
@@ -149,16 +151,15 @@ with gr.Blocks() as shark_web:
iree_vulkan_target_triple
) = (
live_preview
) = debug = stable_diffusion = generated_img = std_output = None
examples = [
["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
[
"A small cabin on top of a snowy mountain in the style of Disney, artstation"
],
]
) = (
debug
) = save_img = stable_diffusion = generated_img = std_output = None
# Load the example prompts from a JSON file next to the web app.
# A missing file is not an error: `prompt_examples` simply stays empty.
prompt_examples = []
prompt_loc = "./prompts.json"
if os.path.exists(prompt_loc):
    # Read the path held in `prompt_loc` rather than repeating the literal,
    # and use a context manager so the file handle is always closed.
    # JSON text is UTF-8; pass the encoding explicitly so the load does not
    # depend on the platform default (e.g. cp1252 on Windows).
    with open(prompt_loc, encoding="utf-8") as prompt_file:
        prompt_examples = json.load(prompt_file)
with gr.Row():
with gr.Column(scale=1, min_width=600):
@@ -169,7 +170,7 @@ with gr.Blocks() as shark_web:
lines=5,
)
ex = gr.Examples(
examples=examples,
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
)
@@ -195,7 +196,12 @@ with gr.Blocks() as shark_web:
1, 100, value=50, step=1, label="Steps"
)
guidance = gr.Slider(
0, 50, value=7.5, step=0.1, label="Guidance Scale"
0,
50,
value=7.5,
step=0.1,
label="Guidance Scale",
interactive=False,
)
with gr.Row():
height = gr.Slider(
@@ -232,7 +238,7 @@ with gr.Blocks() as shark_web:
with gr.Row():
precision = gr.Radio(
label="Precision",
value="fp32",
value="fp16",
choices=["fp16", "fp32"],
)
seed = gr.Textbox(
@@ -240,10 +246,11 @@ with gr.Blocks() as shark_web:
)
with gr.Row():
cache = gr.Checkbox(label="Cache", value=True)
debug = gr.Checkbox(label="DEBUG", value=False)
live_preview = gr.Checkbox(
label="Live Preview", value=False
)
debug = gr.Checkbox(label="DEBUG", value=False)
save_img = gr.Checkbox(label="Save Image", value=False)
iree_vulkan_target_triple = gr.Textbox(
value="",
max_lines=1,
@@ -256,7 +263,7 @@ with gr.Blocks() as shark_web:
std_output = gr.Textbox(
label="Std Output",
value="Nothing.",
lines=10,
lines=5,
visible=False,
)
@@ -283,6 +290,7 @@ with gr.Blocks() as shark_web:
cache,
iree_vulkan_target_triple,
live_preview,
save_img,
],
outputs=[generated_img, std_output],
)

View File

@@ -23,7 +23,6 @@ UNET_FP32 = "unet_fp32"
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
UNET_FP16_TUNED = "unet_fp16_tunedv2"
IREE_EXTRA_ARGS = []
args = None
@@ -44,6 +43,7 @@ class Arguments:
cache: bool,
iree_vulkan_target_triple: str,
live_preview: bool,
save_img: bool,
import_mlir: bool = False,
max_length: int = 77,
use_tuned: bool = True,
@@ -62,6 +62,7 @@ class Arguments:
self.cache = cache
self.iree_vulkan_target_triple = iree_vulkan_target_triple
self.live_preview = live_preview
self.save_img = save_img
self.import_mlir = import_mlir
self.max_length = max_length
self.use_tuned = use_tuned
@@ -69,9 +70,9 @@ class Arguments:
def get_models():
global IREE_EXTRA_ARGS
global args
IREE_EXTRA_ARGS = []
if args.precision == "fp16":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-padding-linalg-ops",
@@ -182,9 +183,10 @@ args = Arguments(
seed=42,
precision="fp16",
device="vulkan",
cache=True,
cache=False,
iree_vulkan_target_triple="",
live_preview=False,
save_img=False,
import_mlir=False,
max_length=77,
use_tuned=True,
@@ -193,6 +195,9 @@ cache_obj["vae_fp16_vulkan"], cache_obj["unet_fp16_vulkan"] = get_models()
args.precision = "fp32"
cache_obj["vae_fp32_vulkan"], cache_obj["unet_fp32_vulkan"] = get_models()
output_dir = "./stored_results/stable_diffusion"
os.makedirs(output_dir, exist_ok=True)
def stable_diff_inf(
prompt: str,
@@ -209,14 +214,15 @@ def stable_diff_inf(
cache: bool,
iree_vulkan_target_triple: str,
live_preview: bool,
save_img: bool,
):
global IREE_EXTRA_ARGS
global args
global schedulers
global cache_obj
global output_dir
output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg"
start = time.time()
# set seed value
if seed == "":
@@ -224,7 +230,9 @@ def stable_diff_inf(
else:
try:
seed = int(seed)
except ValueError:
if seed < 0 or seed > 10000:
seed = hash(seed)
except (ValueError, OverflowError) as error:
seed = hash(seed)
scheduler = schedulers[scheduler]
@@ -243,12 +251,9 @@ def stable_diff_inf(
cache,
iree_vulkan_target_triple,
live_preview,
save_img,
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
if len(args.iree_vulkan_target_triple) > 0:
IREE_EXTRA_ARGS.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
num_inference_steps = int(args.steps) # Number of denoising steps
generator = torch.manual_seed(
args.seed
@@ -310,7 +315,7 @@ def stable_diff_inf(
text_output = ""
for i, t in tqdm(enumerate(scheduler.timesteps)):
text_output = text_output + f"\ni = {i} t = {t} "
text_output += f"\n Iteration = {i} | Timestep = {t} | "
step_start = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latents_numpy = latents.detach().numpy()
@@ -323,7 +328,7 @@ def stable_diff_inf(
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
text_output = text_output + f"time={step_ms}ms"
text_output += f"Time = {step_ms}ms."
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
if live_preview and i % 5 == 0:
@@ -348,8 +353,13 @@ def stable_diff_inf(
out_img = pil_images[0]
avg_ms = 1000 * avg_ms / args.steps
text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"
text_output += f"\n\nAverage step time: {avg_ms}ms/it"
# save the output image with the prompt name.
out_img.save(os.path.join(output_loc))
total_time = time.time() - start
text_output += f"\n\nTotal image generation time: {total_time}sec"
if args.save_img:
# save outputs.
output_loc = f"{output_dir}/{time.time()}_{int(args.steps)}_{args.precision}_{args.device}.jpg"
out_img.save(os.path.join(output_loc))
yield out_img, text_output

9
web/prompts.json Normal file
View File

@@ -0,0 +1,9 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful castle beside a waterfall in the woods, by Josef Thoma, matte painting, trending on artstation HQ"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["A small cabin on top of a snowy mountain in the style of Disney, artstation"]]