Files
AMD-SHARK-Studio/web/index.py
Gaurav Shukla 4b1a0b43ff [WEB] Remove long prompts support
It removes support for long prompts because loading them caused high lag.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs>
2022-11-03 18:57:58 +05:30

242 lines
8.2 KiB
Python

# from models.resnet50 import resnet_inf
# from models.albert_maskfill import albert_maskfill_inf
from models.stable_diffusion.main import stable_diff_inf
# from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
from PIL import Image
import json
import os
def debug_event(debug):
    """Build a Gradio update that sets a Textbox's visibility.

    Args:
        debug: Value forwarded unchanged as the ``visible`` flag.

    Returns:
        The update object produced by ``gr.Textbox.update``.
    """
    update_options = {"visible": debug}
    return gr.Textbox.update(**update_options)
# Location of the optional JSON file with example prompts for the UI.
prompt_loc = "./prompts.json"


def load_prompt_examples(path):
    """Return the example prompts stored as JSON at *path*.

    The parsed JSON value is returned unchanged (it is later handed to
    ``gr.Examples(examples=...)``). A missing file yields an empty list so
    the UI can still start without examples; malformed JSON still raises.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``[]`` if *path* does not exist.
    """
    try:
        with open(path, encoding="utf-8") as fopen:
            return json.load(fopen)
    except FileNotFoundError:
        return []


# Read from prompt_loc itself (previously the path was hard-coded a second
# time in open(), so the check and the read could disagree).
prompt_examples = load_prompt_examples(prompt_loc)
# Custom CSS passed to gr.Blocks(css=...). The #-selectors target elem_id
# values assigned to components in the layout below (e.g. #prompt_box,
# #prompt_examples, #img_result, #top_logo); the last rule hides the page
# footer.
demo_css = """
.gradio-container {background-color: black}
.container {background-color: black !important; padding-top:20px !important; }
#ui_title {padding: 10px !important; }
#top_logo {background-color: transparent; border-radius: 0 !important; border: 0; }
#demo_title {background-color: black; border-radius: 0 !important; border: 0; padding-top: 50px; padding-bottom: 0px; width: 460px !important;}
#demo_title_outer {border-radius: 0; }
#prompt_box_outer div:first-child {border-radius: 0 !important}
#prompt_box textarea {background-color:#1d1d1d !important}
#prompt_examples {margin:0 !important}
#prompt_examples svg {display: none !important;}
.gr-sample-textbox { border-radius: 1rem !important; border-color: rgb(31,41,55) !important; border-width:2px !important; }
#ui_body {background-color: #111111 !important; padding: 10px !important; border-radius: 0.5em !important;}
#img_result+div {display: none !important;}
footer {display: none !important;}
"""
with gr.Blocks(css=demo_css) as shark_web:
    # Gradio derives the page layout from the nesting of these `with`
    # blocks, so the order in which components are created is their
    # on-page order.

    # Title row: the Stable Diffusion demo logo.
    with gr.Row(elem_id="ui_title"):
        with gr.Column(scale=1, elem_id="demo_title_outer"):
            logo2 = Image.open("./logos/sd-demo-logo.png")
            gr.Image(
                value=logo2,
                show_label=False,
                interactive=False,
                elem_id="demo_title",
            ).style(width=230)

    # Main body: generation controls on the left, result on the right.
    with gr.Row(elem_id="ui_body"):
        with gr.Row():
            with gr.Column(scale=1, min_width=600):
                with gr.Group(elem_id="prompt_box_outer"):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A photograph of an astronaut riding a horse",
                        lines=1,
                        elem_id="prompt_box",
                    )
                # Example prompts (module-level prompt_examples, possibly
                # empty), wired to the prompt textbox via `inputs=prompt`.
                with gr.Group():
                    gr.Examples(
                        label="Examples",
                        examples=prompt_examples,
                        inputs=prompt,
                        cache_examples=False,
                        elem_id="prompt_examples",
                    )
                with gr.Row():
                    steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
                    guidance = gr.Slider(
                        0,
                        50,
                        value=7.5,
                        step=0.1,
                        label="Guidance Scale",
                        interactive=False,
                    )
                with gr.Row():
                    height = gr.Slider(
                        384,
                        768,
                        value=512,
                        step=64,
                        label="Height",
                        interactive=False,
                    )
                    width = gr.Slider(
                        384,
                        768,
                        value=512,
                        step=64,
                        label="Width",
                        interactive=False,
                    )
                with gr.Row():
                    precision = gr.Radio(
                        label="Precision",
                        value="fp16",
                        choices=["fp16", "fp32"],
                    )
                    seed = gr.Textbox(value="42", max_lines=1, label="Seed")
                with gr.Row():
                    cache = gr.Checkbox(label="Cache", value=True)
                    save_img = gr.Checkbox(label="Save Image", value=False)
                    live_preview = gr.Checkbox(
                        label="Live Preview", value=False
                    )
                # Hidden inputs: not shown to the user, but still passed to
                # stable_diff_inf (see the click handler) with these values.
                scheduler = gr.Radio(
                    label="Scheduler",
                    value="LMS",
                    choices=["PNDM", "LMS", "DDIM"],
                    interactive=False,
                    visible=False,
                )
                # NOTE(review): elem_id="ugly_line" is reused on several
                # components, which yields duplicate DOM ids.
                device = gr.Radio(
                    label="Device",
                    value="vulkan",
                    choices=["cpu", "cuda", "vulkan"],
                    interactive=False,
                    visible=False,
                    elem_id="ugly_line",
                )
                iters_count = gr.Slider(
                    1,
                    24,
                    value=1,
                    step=1,
                    label="Iteration Count",
                    visible=False,
                )
                batch_size = gr.Slider(
                    1,
                    4,
                    value=1,
                    step=1,
                    label="Batch Size",
                    visible=False,
                )
                iree_vulkan_target_triple = gr.Textbox(
                    value="",
                    max_lines=1,
                    label="IREE VULKAN TARGET TRIPLE",
                    visible=False,
                    elem_id="ugly_line",
                )
                stable_diffusion = gr.Button("Generate Image")
                # AMD/Nod logo shown beneath the controls.
                nod_logo = Image.open("./logos/amd-nod-logo.png")
                gr.Image(
                    value=nod_logo,
                    show_label=False,
                    interactive=False,
                    elem_id="top_logo",
                ).style(width=230)
            with gr.Column(scale=1, min_width=600):
                generated_img = gr.Image(
                    type="pil", elem_id="img_result", interactive=False
                ).style(height=768, width=768)
                # Hidden textbox receiving the second output of
                # stable_diff_inf.
                std_output = gr.Textbox(
                    label="Std Output",
                    value="Nothing.",
                    lines=5,
                    visible=False,
                    elem_id="ugly_line",
                )

    # NOTE: the DEBUG checkbox that toggled std_output through debug_event is
    # disabled, so debug_event currently has no caller.
    # The inputs list order must match stable_diff_inf's positional
    # parameters; its two return values fill generated_img and std_output.
    stable_diffusion.click(
        stable_diff_inf,
        inputs=[
            prompt,
            scheduler,
            iters_count,
            batch_size,
            steps,
            guidance,
            height,
            width,
            seed,
            precision,
            device,
            cache,
            iree_vulkan_target_triple,
            live_preview,
            save_img,
        ],
        outputs=[generated_img, std_output],
        show_progress=False,
    )
# Enable Gradio's request queue via .queue(), the supported mechanism; the
# `enable_queue` argument of launch() is deprecated and redundant once
# .queue() has been called. Then start the server listening on all network
# interfaces ("0.0.0.0") on port 8080. This runs at module import time.
shark_web.queue()
shark_web.launch(server_name="0.0.0.0", server_port=8080)