[WEB] Add shark-web logging

1. This commit adds support to display logs in the shark-web.
2. It also adds nod logo in the home page.
3. Stable-diffusion outputs are being saved now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-09-28 18:03:32 +05:30
parent 56f8a0d85a
commit 0013fb0753
9 changed files with 135 additions and 35 deletions

BIN
web/Nod_logo.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

View File

@@ -4,25 +4,49 @@ from models.stable_diffusion import stable_diff_inf
# from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
from PIL import Image
def debug_event(debug):
return gr.Textbox.update(visible=debug)
with gr.Blocks() as shark_web:
gr.Markdown("Shark Models Demo.")
with gr.Tabs():
with gr.Row():
with gr.Group():
with gr.Column(scale=1):
img = Image.open("./Nod_logo.jpg")
gr.Image(value=img, show_label=False, interactive=False).style(
height=70, width=70
)
with gr.Column(scale=9):
gr.Label(value="Shark Models Demo.")
with gr.Tabs():
with gr.TabItem("ResNet50"):
image = device = resnet = output = None
image = device = debug = resnet = output = std_output = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
image = gr.Image(label="Image")
device = gr.Textbox(label="Device", value="cpu")
debug = gr.Checkbox(label="DEBUG", value=False)
resnet = gr.Button("Recognize Image").style(
full_width=True
)
with gr.Column(scale=1, min_width=600):
output = gr.Label(label="Output")
std_output = gr.Textbox(
label="Std Output", value="Nothing."
label="Std Output",
value="Nothing to show.",
visible=False,
)
debug.change(
debug_event,
inputs=[debug],
outputs=[std_output],
show_progress=False,
)
resnet.click(
resnet_inf,
inputs=[image, device],
@@ -30,7 +54,9 @@ with gr.Blocks() as shark_web:
)
with gr.TabItem("Albert MaskFill"):
masked_text = device = albert_mask = decoded_res = None
masked_text = (
device
) = debug = albert_mask = decoded_res = std_output = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
masked_text = gr.Textbox(
@@ -38,12 +64,21 @@ with gr.Blocks() as shark_web:
placeholder="Give me a sentence with [MASK] to fill",
)
device = gr.Textbox(label="Device", value="cpu")
debug = gr.Checkbox(label="DEBUG", value=False)
albert_mask = gr.Button("Decode Mask")
with gr.Column(scale=1, min_width=600):
decoded_res = gr.Label(label="Decoded Results")
std_output = gr.Textbox(
label="Std Output", value="Nothing."
label="Std Output",
value="Nothing to show.",
visible=False,
)
debug.change(
debug_event,
inputs=[debug],
outputs=[std_output],
show_progress=False,
)
albert_mask.click(
albert_maskfill_inf,
inputs=[masked_text, device],
@@ -74,7 +109,9 @@ with gr.Blocks() as shark_web:
with gr.TabItem("Stable-Diffusion"):
prompt = (
iters
) = mlir_loc = device = stable_diffusion = generated_img = None
) = (
device
) = debug = stable_diffusion = generated_img = std_output = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
prompt = gr.Textbox(
@@ -82,20 +119,23 @@ with gr.Blocks() as shark_web:
value="a photograph of an astronaut riding a horse",
)
iters = gr.Number(label="Steps", value=2)
mlir_loc = gr.Textbox(
label="Location of MLIR(Relative to SHARK/web/)",
value="./stable_diffusion.mlir",
)
device = gr.Textbox(label="Device", value="vulkan")
debug = gr.Checkbox(label="DEBUG", value=False)
stable_diffusion = gr.Button("Generate image from prompt")
with gr.Column(scale=1, min_width=600):
generated_img = gr.Image(type="pil", shape=(100, 100))
std_output = gr.Textbox(
label="Std Output", value="Nothing."
label="Std Output", value="Nothing.", visible=False
)
debug.change(
debug_event,
inputs=[debug],
outputs=[std_output],
show_progress=False,
)
stable_diffusion.click(
stable_diff_inf,
inputs=[prompt, iters, mlir_loc, device],
inputs=[prompt, iters, device],
outputs=[generated_img, std_output],
)

View File

0
web/logs/resnet50_log.txt Executable file
View File

View File

View File

@@ -21,6 +21,7 @@ class AlbertModule(torch.nn.Module):
################################## Preprocessing inputs ####################
DEBUG = False
compiled_module = {}
compiled_module["tokenizer"] = AutoTokenizer.from_pretrained("albert-base-v2")
@@ -42,10 +43,13 @@ def preprocess_data(text):
return inputs
def top5_possibilities(text, inputs, token_logits):
def top5_possibilities(text, inputs, token_logits, log_write):
global DEBUG
global compiled_module
if DEBUG:
log_write.write("Retrieving top 5 possible outcomes.\n")
tokenizer = compiled_module["tokenizer"]
mask_id = torch.where(inputs[0] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_id, :]
@@ -55,6 +59,8 @@ def top5_possibilities(text, inputs, token_logits):
for token in top_5_tokens:
label = text.replace(tokenizer.mask_token, tokenizer.decode(token))
top5[label] = percentage[token].item()
if DEBUG:
log_write.write("Done.\n")
return top5
@@ -63,10 +69,18 @@ def top5_possibilities(text, inputs, token_logits):
def albert_maskfill_inf(masked_text, device):
global DEBUG
global compiled_module
DEBUG = False
log_write = open(r"logs/albert_maskfill_log.txt", "w")
if log_write:
DEBUG = True
inputs = preprocess_data(masked_text)
if device not in compiled_module.keys():
if DEBUG:
log_write.write("Compiling the Albert Maskfill module.\n")
mlir_importer = SharkImporter(
AlbertModule(),
inputs,
@@ -80,6 +94,15 @@ def albert_maskfill_inf(masked_text, device):
)
shark_module.compile()
compiled_module[device] = shark_module
if DEBUG:
log_write.write("Compilation successful.\n")
token_logits = torch.tensor(compiled_module[device].forward(inputs))
return top5_possibilities(masked_text, inputs, token_logits), "Testing.."
output = top5_possibilities(masked_text, inputs, token_logits, log_write)
log_write.close()
std_output = ""
with open(r"logs/albert_maskfill_log.txt", "r") as log_read:
std_output = log_read.read()
return output, std_output

View File

@@ -7,6 +7,9 @@ from shark.shark_downloader import download_torch_model
################################## Preprocessing inputs and helper functions ########
DEBUG = False
compiled_module = {}
def preprocess_image(img):
image = Image.fromarray(img)
@@ -33,43 +36,57 @@ def load_labels():
return labels
def top3_possibilities(res):
def top3_possibilities(res, log_write):
global DEBUG
if DEBUG:
log_write.write("Retrieving top 3 possible outcomes.\n")
labels = load_labels()
_, indexes = torch.sort(res, descending=True)
percentage = torch.nn.functional.softmax(res, dim=1)[0]
top3 = dict(
[(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
)
if DEBUG:
log_write.write("Done.\n")
return top3
##############################################################################
compiled_module = {}
def resnet_inf(numpy_img, device):
global DEBUG
global compiled_module
std_output = ""
DEBUG = False
log_write = open(r"logs/resnet50_log.txt", "w")
if log_write:
DEBUG = True
if device not in compiled_module.keys():
std_output += "Compiling the Resnet50 module.\n"
if DEBUG:
log_write.write("Compiling the Resnet50 module.\n")
mlir_model, func_name, inputs, golden_out = download_torch_model(
"resnet50"
)
shark_module = SharkInference(
mlir_model, func_name, device=device, mlir_dialect="linalg"
)
shark_module.compile()
std_output += "Compilation successful.\n"
compiled_module[device] = shark_module
if DEBUG:
log_write.write("Compilation successful.\n")
img = preprocess_image(numpy_img)
result = compiled_module[device].forward((img.detach().numpy(),))
output = top3_possibilities(torch.from_numpy(result), log_write)
log_write.close()
# print("The top 3 results obtained via shark_runner is:")
std_output += "Retrieving top 3 possible outcomes.\n"
return top3_possibilities(torch.from_numpy(result)), std_output
std_output = ""
with open(r"logs/resnet50_log.txt", "r") as log_read:
std_output = log_read.read()
return output, std_output

View File

@@ -10,16 +10,14 @@ from torch._decomp import get_decompositions
import torch_mlir
import tempfile
import numpy as np
import os
##############################################################################
def load_mlir(mlir_loc):
import os
if mlir_loc == None:
return None
print(f"Trying to load the model from {mlir_loc}.")
with open(os.path.join(mlir_loc)) as f:
mlir_module = f.read()
return mlir_module
@@ -85,21 +83,30 @@ def compile_through_fx(model, inputs, device, mlir_loc=None):
##############################################################################
DEBUG = False
compiled_module = {}
def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
def stable_diff_inf(prompt: str, steps, device: str):
args = {}
args["prompt"] = [prompt]
args["steps"] = steps
args["device"] = device
args["mlir_loc"] = mlir_loc
args["mlir_loc"] = "./stable_diffusion.mlir"
output_loc = (
f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{device}.jpg"
)
global DEBUG
global compiled_module
if args["device"] not in compiled_module.keys():
DEBUG = False
log_write = open(r"logs/stable_diffusion_log.txt", "w")
if log_write:
DEBUG = True
if args["device"] not in compiled_module.keys():
YOUR_TOKEN = "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"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
@@ -116,6 +123,8 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
compiled_module["text_encoder"] = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if DEBUG:
log_write.write("Compiling the Unet module.\n")
# Wrap the unet model to return tuples.
class UnetModel(torch.nn.Module):
@@ -143,14 +152,16 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
args["mlir_loc"],
)
compiled_module[args["device"]] = shark_unet
if DEBUG:
log_write.write("Compilation successful.\n")
compiled_module["unet"] = unet
compiled_module["scheduler"] = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
compiled_module["unet"] = unet
shark_unet = compiled_module[args["device"]]
vae = compiled_module["vae"]
@@ -202,7 +213,8 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
for i, t in tqdm(enumerate(scheduler.timesteps)):
print(f"i = {i} t = {t}")
if DEBUG:
log_write.write(f"i = {i} t = {t}\n")
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
sigma = scheduler.sigmas[i]
@@ -232,11 +244,19 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
print(latents.shape)
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images[0], "Testing.."
output = pil_images[0]
# save the output image with the prompt name.
output.save(os.path.join(output_loc))
log_write.close()
std_output = ""
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
std_output = log_read.read()
return output, std_output