mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
[WEB] Load prompts from json
The prompt examples will now be loaded from a json file `prompts.json`. Signed-Off-by: Gaurav Shukla
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
In order to launch SHARK-web, from the root SHARK directory, run:
|
||||
|
||||
## Linux
|
||||
```shell
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
@@ -7,5 +8,9 @@ cd web
|
||||
python index.py
|
||||
```
|
||||
|
||||
This will launch a gradio server with a public URL like:
|
||||
Running on public URL: https://xxxxx.gradio.app
|
||||
## Windows
|
||||
```shell
|
||||
./setup_venv.ps1
|
||||
cd web
|
||||
python index.py --local_tank_cache=<current_working_dir>
|
||||
```
|
||||
|
||||
38
web/index.py
38
web/index.py
@@ -5,6 +5,8 @@ from models.stable_diffusion.main import stable_diff_inf
|
||||
# from models.diffusion.v_diffusion import vdiff_inf
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def debug_event(debug):
|
||||
@@ -149,16 +151,15 @@ with gr.Blocks() as shark_web:
|
||||
iree_vulkan_target_triple
|
||||
) = (
|
||||
live_preview
|
||||
) = debug = stable_diffusion = generated_img = std_output = None
|
||||
examples = [
|
||||
["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
[
|
||||
"A small cabin on top of a snowy mountain in the style of Disney, artstation"
|
||||
],
|
||||
]
|
||||
) = (
|
||||
debug
|
||||
) = save_img = stable_diffusion = generated_img = std_output = None
|
||||
# load prompts.
|
||||
prompt_examples = []
|
||||
prompt_loc = "./prompts.json"
|
||||
if os.path.exists(prompt_loc):
|
||||
fopen = open("./prompts.json")
|
||||
prompt_examples = json.load(fopen)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
@@ -169,7 +170,7 @@ with gr.Blocks() as shark_web:
|
||||
lines=5,
|
||||
)
|
||||
ex = gr.Examples(
|
||||
examples=examples,
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
)
|
||||
@@ -195,7 +196,12 @@ with gr.Blocks() as shark_web:
|
||||
1, 100, value=50, step=1, label="Steps"
|
||||
)
|
||||
guidance = gr.Slider(
|
||||
0, 50, value=7.5, step=0.1, label="Guidance Scale"
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -232,7 +238,7 @@ with gr.Blocks() as shark_web:
|
||||
with gr.Row():
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
value="fp16",
|
||||
choices=["fp16", "fp32"],
|
||||
)
|
||||
seed = gr.Textbox(
|
||||
@@ -240,10 +246,11 @@ with gr.Blocks() as shark_web:
|
||||
)
|
||||
with gr.Row():
|
||||
cache = gr.Checkbox(label="Cache", value=True)
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
live_preview = gr.Checkbox(
|
||||
label="Live Preview", value=False
|
||||
)
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
save_img = gr.Checkbox(label="Save Image", value=False)
|
||||
iree_vulkan_target_triple = gr.Textbox(
|
||||
value="",
|
||||
max_lines=1,
|
||||
@@ -256,7 +263,7 @@ with gr.Blocks() as shark_web:
|
||||
std_output = gr.Textbox(
|
||||
label="Std Output",
|
||||
value="Nothing.",
|
||||
lines=10,
|
||||
lines=5,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
@@ -283,6 +290,7 @@ with gr.Blocks() as shark_web:
|
||||
cache,
|
||||
iree_vulkan_target_triple,
|
||||
live_preview,
|
||||
save_img,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ UNET_FP32 = "unet_fp32"
|
||||
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
|
||||
UNET_FP16_TUNED = "unet_fp16_tunedv2"
|
||||
|
||||
IREE_EXTRA_ARGS = []
|
||||
args = None
|
||||
|
||||
|
||||
@@ -44,6 +43,7 @@ class Arguments:
|
||||
cache: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
live_preview: bool,
|
||||
save_img: bool,
|
||||
import_mlir: bool = False,
|
||||
max_length: int = 77,
|
||||
use_tuned: bool = True,
|
||||
@@ -62,6 +62,7 @@ class Arguments:
|
||||
self.cache = cache
|
||||
self.iree_vulkan_target_triple = iree_vulkan_target_triple
|
||||
self.live_preview = live_preview
|
||||
self.save_img = save_img
|
||||
self.import_mlir = import_mlir
|
||||
self.max_length = max_length
|
||||
self.use_tuned = use_tuned
|
||||
@@ -69,9 +70,9 @@ class Arguments:
|
||||
|
||||
def get_models():
|
||||
|
||||
global IREE_EXTRA_ARGS
|
||||
global args
|
||||
|
||||
IREE_EXTRA_ARGS = []
|
||||
if args.precision == "fp16":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
@@ -182,9 +183,10 @@ args = Arguments(
|
||||
seed=42,
|
||||
precision="fp16",
|
||||
device="vulkan",
|
||||
cache=True,
|
||||
cache=False,
|
||||
iree_vulkan_target_triple="",
|
||||
live_preview=False,
|
||||
save_img=False,
|
||||
import_mlir=False,
|
||||
max_length=77,
|
||||
use_tuned=True,
|
||||
@@ -193,6 +195,9 @@ cache_obj["vae_fp16_vulkan"], cache_obj["unet_fp16_vulkan"] = get_models()
|
||||
args.precision = "fp32"
|
||||
cache_obj["vae_fp32_vulkan"], cache_obj["unet_fp32_vulkan"] = get_models()
|
||||
|
||||
output_dir = "./stored_results/stable_diffusion"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
prompt: str,
|
||||
@@ -209,14 +214,15 @@ def stable_diff_inf(
|
||||
cache: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
live_preview: bool,
|
||||
save_img: bool,
|
||||
):
|
||||
|
||||
global IREE_EXTRA_ARGS
|
||||
global args
|
||||
global schedulers
|
||||
global cache_obj
|
||||
global output_dir
|
||||
|
||||
output_loc = f"stored_results/stable_diffusion/{time.time()}_{int(steps)}_{precision}_{device}.jpg"
|
||||
start = time.time()
|
||||
|
||||
# set seed value
|
||||
if seed == "":
|
||||
@@ -224,7 +230,9 @@ def stable_diff_inf(
|
||||
else:
|
||||
try:
|
||||
seed = int(seed)
|
||||
except ValueError:
|
||||
if seed < 0 or seed > 10000:
|
||||
seed = hash(seed)
|
||||
except (ValueError, OverflowError) as error:
|
||||
seed = hash(seed)
|
||||
|
||||
scheduler = schedulers[scheduler]
|
||||
@@ -243,12 +251,9 @@ def stable_diff_inf(
|
||||
cache,
|
||||
iree_vulkan_target_triple,
|
||||
live_preview,
|
||||
save_img,
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
IREE_EXTRA_ARGS.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
num_inference_steps = int(args.steps) # Number of denoising steps
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
@@ -310,7 +315,7 @@ def stable_diff_inf(
|
||||
text_output = ""
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
text_output = text_output + f"\ni = {i} t = {t} "
|
||||
text_output += f"\n Iteration = {i} | Timestep = {t} | "
|
||||
step_start = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latents_numpy = latents.detach().numpy()
|
||||
@@ -323,7 +328,7 @@ def stable_diff_inf(
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
text_output = text_output + f"time={step_ms}ms"
|
||||
text_output += f"Time = {step_ms}ms."
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
if live_preview and i % 5 == 0:
|
||||
@@ -348,8 +353,13 @@ def stable_diff_inf(
|
||||
out_img = pil_images[0]
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"
|
||||
text_output += f"\n\nAverage step time: {avg_ms}ms/it"
|
||||
|
||||
# save the output image with the prompt name.
|
||||
out_img.save(os.path.join(output_loc))
|
||||
total_time = time.time() - start
|
||||
text_output += f"\n\nTotal image generation time: {total_time}sec"
|
||||
|
||||
if args.save_img:
|
||||
# save outputs.
|
||||
output_loc = f"{output_dir}/{time.time()}_{int(args.steps)}_{args.precision}_{args.device}.jpg"
|
||||
out_img.save(os.path.join(output_loc))
|
||||
yield out_img, text_output
|
||||
|
||||
9
web/prompts.json
Normal file
9
web/prompts.json
Normal file
@@ -0,0 +1,9 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
|
||||
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
|
||||
["A beautiful castle beside a waterfall in the woods, by Josef Thoma, matte painting, trending on artstation HQ"],
|
||||
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
|
||||
["A small cabin on top of a snowy mountain in the style of Disney, artstation"]]
|
||||
Reference in New Issue
Block a user