mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
This commit adds `v_diffusion` model web visualization as a part of shark web. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com> Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
46 lines
1.8 KiB
Python
46 lines
1.8 KiB
Python
from models.resnet50 import resnet_inf
|
|
from models.albert_maskfill import albert_maskfill_inf
|
|
from models.diffusion.v_diffusion import vdiff_inf
|
|
import gradio as gr
|
|
|
|
# Build the web UI: a single row holding one column per model demo.
# Components are created in the same order as before, since creation order
# inside a layout context determines on-page order.
with gr.Blocks() as shark_web:
    with gr.Row():
        # Column 1: image input -> label output, handled by resnet_inf.
        with gr.Column(), gr.Group():
            image = gr.Image(label="Image")
            label = gr.Label(label="Output")
            resnet = gr.Button("Recognize Image")
            resnet.click(resnet_inf, inputs=image, outputs=label)

        # Column 2: masked sentence -> decoded label, handled by
        # albert_maskfill_inf.
        with gr.Column(), gr.Group():
            masked_text = gr.Textbox(
                label="Masked Text",
                placeholder="Give me a sentence with [MASK] to fill",
            )
            decoded_res = gr.Label(label="Decoded Results")
            albert_mask = gr.Button("Decode Mask")
            albert_mask.click(
                albert_maskfill_inf, inputs=masked_text, outputs=decoded_res
            )

        # Column 3: prompt plus three numeric settings -> generated image,
        # handled by vdiff_inf.
        with gr.Column(), gr.Group():
            prompt = gr.Textbox(
                label="Prompt",
                value="New York City, oil on canvas:5",
            )
            sample_count = gr.Number(label="Sample Count", value=1)
            batch_size = gr.Number(label="Batch Size", value=1)
            iters = gr.Number(label="Steps", value=2)
            v_diffusion = gr.Button("Generate image from prompt")
            generated_img = gr.Image(type="pil", shape=(100, 100))
            # The handler receives the inputs positionally in this list order.
            v_diffusion.click(
                vdiff_inf,
                inputs=[prompt, sample_count, batch_size, iters],
                outputs=generated_img,
            )

# Start the server on port 8080 with Gradio's share option enabled.
shark_web.launch(server_port=8080, share=True)
|