mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[WEB] Add shark-web logging
1. This commit adds support for displaying logs in shark-web. 2. It also adds the Nod logo to the home page. 3. Stable-diffusion outputs are now saved to disk. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
BIN
web/Nod_logo.jpg
Normal file
BIN
web/Nod_logo.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 41 KiB |
66
web/index.py
66
web/index.py
@@ -4,25 +4,49 @@ from models.stable_diffusion import stable_diff_inf
|
||||
|
||||
# from models.diffusion.v_diffusion import vdiff_inf
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def debug_event(debug):
|
||||
return gr.Textbox.update(visible=debug)
|
||||
|
||||
|
||||
with gr.Blocks() as shark_web:
|
||||
gr.Markdown("Shark Models Demo.")
|
||||
with gr.Tabs():
|
||||
|
||||
with gr.Row():
|
||||
with gr.Group():
|
||||
with gr.Column(scale=1):
|
||||
img = Image.open("./Nod_logo.jpg")
|
||||
gr.Image(value=img, show_label=False, interactive=False).style(
|
||||
height=70, width=70
|
||||
)
|
||||
with gr.Column(scale=9):
|
||||
gr.Label(value="Shark Models Demo.")
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("ResNet50"):
|
||||
image = device = resnet = output = None
|
||||
image = device = debug = resnet = output = std_output = None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
image = gr.Image(label="Image")
|
||||
device = gr.Textbox(label="Device", value="cpu")
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
resnet = gr.Button("Recognize Image").style(
|
||||
full_width=True
|
||||
)
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
output = gr.Label(label="Output")
|
||||
std_output = gr.Textbox(
|
||||
label="Std Output", value="Nothing."
|
||||
label="Std Output",
|
||||
value="Nothing to show.",
|
||||
visible=False,
|
||||
)
|
||||
debug.change(
|
||||
debug_event,
|
||||
inputs=[debug],
|
||||
outputs=[std_output],
|
||||
show_progress=False,
|
||||
)
|
||||
resnet.click(
|
||||
resnet_inf,
|
||||
inputs=[image, device],
|
||||
@@ -30,7 +54,9 @@ with gr.Blocks() as shark_web:
|
||||
)
|
||||
|
||||
with gr.TabItem("Albert MaskFill"):
|
||||
masked_text = device = albert_mask = decoded_res = None
|
||||
masked_text = (
|
||||
device
|
||||
) = debug = albert_mask = decoded_res = std_output = None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
masked_text = gr.Textbox(
|
||||
@@ -38,12 +64,21 @@ with gr.Blocks() as shark_web:
|
||||
placeholder="Give me a sentence with [MASK] to fill",
|
||||
)
|
||||
device = gr.Textbox(label="Device", value="cpu")
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
albert_mask = gr.Button("Decode Mask")
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
decoded_res = gr.Label(label="Decoded Results")
|
||||
std_output = gr.Textbox(
|
||||
label="Std Output", value="Nothing."
|
||||
label="Std Output",
|
||||
value="Nothing to show.",
|
||||
visible=False,
|
||||
)
|
||||
debug.change(
|
||||
debug_event,
|
||||
inputs=[debug],
|
||||
outputs=[std_output],
|
||||
show_progress=False,
|
||||
)
|
||||
albert_mask.click(
|
||||
albert_maskfill_inf,
|
||||
inputs=[masked_text, device],
|
||||
@@ -74,7 +109,9 @@ with gr.Blocks() as shark_web:
|
||||
with gr.TabItem("Stable-Diffusion"):
|
||||
prompt = (
|
||||
iters
|
||||
) = mlir_loc = device = stable_diffusion = generated_img = None
|
||||
) = (
|
||||
device
|
||||
) = debug = stable_diffusion = generated_img = std_output = None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
prompt = gr.Textbox(
|
||||
@@ -82,20 +119,23 @@ with gr.Blocks() as shark_web:
|
||||
value="a photograph of an astronaut riding a horse",
|
||||
)
|
||||
iters = gr.Number(label="Steps", value=2)
|
||||
mlir_loc = gr.Textbox(
|
||||
label="Location of MLIR(Relative to SHARK/web/)",
|
||||
value="./stable_diffusion.mlir",
|
||||
)
|
||||
device = gr.Textbox(label="Device", value="vulkan")
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
stable_diffusion = gr.Button("Generate image from prompt")
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
generated_img = gr.Image(type="pil", shape=(100, 100))
|
||||
std_output = gr.Textbox(
|
||||
label="Std Output", value="Nothing."
|
||||
label="Std Output", value="Nothing.", visible=False
|
||||
)
|
||||
debug.change(
|
||||
debug_event,
|
||||
inputs=[debug],
|
||||
outputs=[std_output],
|
||||
show_progress=False,
|
||||
)
|
||||
stable_diffusion.click(
|
||||
stable_diff_inf,
|
||||
inputs=[prompt, iters, mlir_loc, device],
|
||||
inputs=[prompt, iters, device],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
|
||||
0
web/logs/albert_maskfill_log.txt
Executable file
0
web/logs/albert_maskfill_log.txt
Executable file
0
web/logs/resnet50_log.txt
Executable file
0
web/logs/resnet50_log.txt
Executable file
0
web/logs/stable_diffusion_log.txt
Executable file
0
web/logs/stable_diffusion_log.txt
Executable file
@@ -21,6 +21,7 @@ class AlbertModule(torch.nn.Module):
|
||||
|
||||
################################## Preprocessing inputs ####################
|
||||
|
||||
DEBUG = False
|
||||
compiled_module = {}
|
||||
compiled_module["tokenizer"] = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
|
||||
@@ -42,10 +43,13 @@ def preprocess_data(text):
|
||||
return inputs
|
||||
|
||||
|
||||
def top5_possibilities(text, inputs, token_logits):
|
||||
def top5_possibilities(text, inputs, token_logits, log_write):
|
||||
|
||||
global DEBUG
|
||||
global compiled_module
|
||||
|
||||
if DEBUG:
|
||||
log_write.write("Retrieving top 5 possible outcomes.\n")
|
||||
tokenizer = compiled_module["tokenizer"]
|
||||
mask_id = torch.where(inputs[0] == tokenizer.mask_token_id)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
@@ -55,6 +59,8 @@ def top5_possibilities(text, inputs, token_logits):
|
||||
for token in top_5_tokens:
|
||||
label = text.replace(tokenizer.mask_token, tokenizer.decode(token))
|
||||
top5[label] = percentage[token].item()
|
||||
if DEBUG:
|
||||
log_write.write("Done.\n")
|
||||
return top5
|
||||
|
||||
|
||||
@@ -63,10 +69,18 @@ def top5_possibilities(text, inputs, token_logits):
|
||||
|
||||
def albert_maskfill_inf(masked_text, device):
|
||||
|
||||
global DEBUG
|
||||
global compiled_module
|
||||
|
||||
DEBUG = False
|
||||
log_write = open(r"logs/albert_maskfill_log.txt", "w")
|
||||
if log_write:
|
||||
DEBUG = True
|
||||
|
||||
inputs = preprocess_data(masked_text)
|
||||
if device not in compiled_module.keys():
|
||||
if DEBUG:
|
||||
log_write.write("Compiling the Albert Maskfill module.\n")
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
@@ -80,6 +94,15 @@ def albert_maskfill_inf(masked_text, device):
|
||||
)
|
||||
shark_module.compile()
|
||||
compiled_module[device] = shark_module
|
||||
if DEBUG:
|
||||
log_write.write("Compilation successful.\n")
|
||||
|
||||
token_logits = torch.tensor(compiled_module[device].forward(inputs))
|
||||
return top5_possibilities(masked_text, inputs, token_logits), "Testing.."
|
||||
output = top5_possibilities(masked_text, inputs, token_logits, log_write)
|
||||
log_write.close()
|
||||
|
||||
std_output = ""
|
||||
with open(r"logs/albert_maskfill_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
|
||||
return output, std_output
|
||||
|
||||
@@ -7,6 +7,9 @@ from shark.shark_downloader import download_torch_model
|
||||
|
||||
################################## Preprocessing inputs and helper functions ########
|
||||
|
||||
DEBUG = False
|
||||
compiled_module = {}
|
||||
|
||||
|
||||
def preprocess_image(img):
|
||||
image = Image.fromarray(img)
|
||||
@@ -33,43 +36,57 @@ def load_labels():
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
def top3_possibilities(res, log_write):
|
||||
|
||||
global DEBUG
|
||||
|
||||
if DEBUG:
|
||||
log_write.write("Retrieving top 3 possible outcomes.\n")
|
||||
labels = load_labels()
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0]
|
||||
top3 = dict(
|
||||
[(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
)
|
||||
if DEBUG:
|
||||
log_write.write("Done.\n")
|
||||
return top3
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
compiled_module = {}
|
||||
|
||||
|
||||
def resnet_inf(numpy_img, device):
|
||||
|
||||
global DEBUG
|
||||
global compiled_module
|
||||
|
||||
std_output = ""
|
||||
DEBUG = False
|
||||
log_write = open(r"logs/resnet50_log.txt", "w")
|
||||
if log_write:
|
||||
DEBUG = True
|
||||
|
||||
if device not in compiled_module.keys():
|
||||
std_output += "Compiling the Resnet50 module.\n"
|
||||
if DEBUG:
|
||||
log_write.write("Compiling the Resnet50 module.\n")
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"resnet50"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
std_output += "Compilation successful.\n"
|
||||
compiled_module[device] = shark_module
|
||||
if DEBUG:
|
||||
log_write.write("Compilation successful.\n")
|
||||
|
||||
img = preprocess_image(numpy_img)
|
||||
result = compiled_module[device].forward((img.detach().numpy(),))
|
||||
output = top3_possibilities(torch.from_numpy(result), log_write)
|
||||
log_write.close()
|
||||
|
||||
# print("The top 3 results obtained via shark_runner is:")
|
||||
std_output += "Retrieving top 3 possible outcomes.\n"
|
||||
return top3_possibilities(torch.from_numpy(result)), std_output
|
||||
std_output = ""
|
||||
with open(r"logs/resnet50_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
|
||||
return output, std_output
|
||||
|
||||
@@ -10,16 +10,14 @@ from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
##############################################################################
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
@@ -85,21 +83,30 @@ def compile_through_fx(model, inputs, device, mlir_loc=None):
|
||||
|
||||
##############################################################################
|
||||
|
||||
DEBUG = False
|
||||
compiled_module = {}
|
||||
|
||||
|
||||
def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
|
||||
def stable_diff_inf(prompt: str, steps, device: str):
|
||||
|
||||
args = {}
|
||||
args["prompt"] = [prompt]
|
||||
args["steps"] = steps
|
||||
args["device"] = device
|
||||
args["mlir_loc"] = mlir_loc
|
||||
args["mlir_loc"] = "./stable_diffusion.mlir"
|
||||
output_loc = (
|
||||
f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{device}.jpg"
|
||||
)
|
||||
|
||||
global DEBUG
|
||||
global compiled_module
|
||||
|
||||
if args["device"] not in compiled_module.keys():
|
||||
DEBUG = False
|
||||
log_write = open(r"logs/stable_diffusion_log.txt", "w")
|
||||
if log_write:
|
||||
DEBUG = True
|
||||
|
||||
if args["device"] not in compiled_module.keys():
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
||||
@@ -116,6 +123,8 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
|
||||
compiled_module["text_encoder"] = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if DEBUG:
|
||||
log_write.write("Compiling the Unet module.\n")
|
||||
|
||||
# Wrap the unet model to return tuples.
|
||||
class UnetModel(torch.nn.Module):
|
||||
@@ -143,14 +152,16 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
|
||||
args["mlir_loc"],
|
||||
)
|
||||
compiled_module[args["device"]] = shark_unet
|
||||
if DEBUG:
|
||||
log_write.write("Compilation successful.\n")
|
||||
|
||||
compiled_module["unet"] = unet
|
||||
compiled_module["scheduler"] = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
compiled_module["unet"] = unet
|
||||
|
||||
shark_unet = compiled_module[args["device"]]
|
||||
vae = compiled_module["vae"]
|
||||
@@ -202,7 +213,8 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
|
||||
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
print(f"i = {i} t = {t}")
|
||||
if DEBUG:
|
||||
log_write.write(f"i = {i} t = {t}\n")
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
sigma = scheduler.sigmas[i]
|
||||
@@ -232,11 +244,19 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
print(latents.shape)
|
||||
image = vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
return pil_images[0], "Testing.."
|
||||
output = pil_images[0]
|
||||
# save the output image with the prompt name.
|
||||
output.save(os.path.join(output_loc))
|
||||
log_write.close()
|
||||
|
||||
std_output = ""
|
||||
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
|
||||
return output, std_output
|
||||
|
||||
0
web/stored_results/stable_diffusion/empty.jpg
Normal file
0
web/stored_results/stable_diffusion/empty.jpg
Normal file
Reference in New Issue
Block a user