Files
AMD-SHARK-Studio/web/index.py
Gaurav Shukla a886cba655 [WEB] Add v_diffusion model in the shark web (#306)
This commit adds the `v_diffusion` model web visualization as a part of
the shark web.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-09-01 06:34:51 -07:00

46 lines
1.8 KiB
Python

from models.resnet50 import resnet_inf
from models.albert_maskfill import albert_maskfill_inf
from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
# Top-level Blocks container holding every model demo on a single page.
shark_web = gr.Blocks()

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; it assumes all three columns live in one row and
# that launch() runs after the layout context is closed -- confirm.
with shark_web:
    with gr.Row():
        # Column 1: image input and label output; the button runs
        # resnet_inf on the image and writes the result to the label.
        with gr.Column(), gr.Group():
            image = gr.Image(label="Image")
            label = gr.Label(label="Output")
            resnet = gr.Button("Recognize Image")
            resnet.click(fn=resnet_inf, inputs=image, outputs=label)

        # Column 2: text box for a sentence containing [MASK]; the button
        # runs albert_maskfill_inf and shows its output as a label.
        with gr.Column(), gr.Group():
            masked_text = gr.Textbox(
                label="Masked Text",
                placeholder="Give me a sentence with [MASK] to fill",
            )
            decoded_res = gr.Label(label="Decoded Results")
            albert_mask = gr.Button("Decode Mask")
            albert_mask.click(
                fn=albert_maskfill_inf,
                inputs=masked_text,
                outputs=decoded_res,
            )

        # Column 3: prompt plus three numeric fields; the button passes all
        # four values, in this order, to vdiff_inf and displays the image
        # it returns.
        with gr.Column(), gr.Group():
            prompt = gr.Textbox(
                label="Prompt",
                value="New York City, oil on canvas:5",
            )
            sample_count = gr.Number(label="Sample Count", value=1)
            batch_size = gr.Number(label="Batch Size", value=1)
            iters = gr.Number(label="Steps", value=2)
            v_diffusion = gr.Button("Generate image from prompt")
            generated_img = gr.Image(type="pil", shape=(100, 100))
            v_diffusion.click(
                fn=vdiff_inf,
                inputs=[prompt, sample_count, batch_size, iters],
                outputs=generated_img,
            )

# Serve the page on port 8080; share=True also requests a public share link.
shark_web.launch(share=True, server_port=8080)