[WEB] Introduce web interface for the SHARK models (#298)

This commit introduces web application for SHARK using gradio platform.
This adds web visualization of `Resnet50` and `Albert_Maskfill` models
as a start.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-09-01 11:47:52 +05:30
committed by GitHub
parent 3703f014d9
commit fe080eaee6
4 changed files with 143 additions and 0 deletions

20
shark_web/index.py Normal file
View File

@@ -0,0 +1,20 @@
from models.resnet50 import resnet_inf
from models.albert_maskfill import albert_maskfill_inf
import gradio as gr
shark_web = gr.Blocks()
with shark_web:
image = gr.Image()
label1 = gr.Label()
resnet = gr.Button("Recognize Image")
text = gr.Textbox()
label2 = gr.Label()
albert_mask = gr.Button("Decode Mask")
resnet.click(resnet_inf, inputs=image, outputs=label1)
albert_mask.click(albert_maskfill_inf, inputs=text, outputs=label2)
shark_web.launch(share=True)

View File

View File

@@ -0,0 +1,69 @@
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import numpy as np
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
class AlbertModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
self.model.eval()
def forward(self, input_ids, attention_mask):
return self.model(
input_ids=input_ids, attention_mask=attention_mask
).logits
################################## Preprocessing inputs and model ############
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
def preprocess_data(text):
# Preparing Data
encoded_inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
return inputs
def top5_possibilities(text, inputs, token_logits):
mask_id = torch.where(
inputs[0] == tokenizer.mask_token_id
)[1]
mask_token_logits = token_logits[0, mask_id, :]
percentage = torch.nn.functional.softmax(mask_token_logits, dim=1)[0]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
top5 = {}
for token in top_5_tokens:
label = text.replace(tokenizer.mask_token, tokenizer.decode(token))
top5[label] = percentage[token].item()
return top5
##############################################################################
def albert_maskfill_inf(masked_text):
inputs = preprocess_data(masked_text)
mlir_importer = SharkImporter(
AlbertModule(),
inputs,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg"
)
shark_module.compile()
token_logits = torch.tensor(shark_module.forward(inputs))
return top5_possibilities(masked_text, inputs, token_logits)

View File

@@ -0,0 +1,54 @@
from PIL import Image
import requests
import torch
from torchvision import transforms
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
################################## Preprocessing inputs and model ############
def preprocess_image(img):
image = Image.fromarray(img)
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
img_preprocessed = preprocess(image)
return torch.unsqueeze(img_preprocessed, 0)
def load_labels():
classes_text = requests.get(
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
stream=True,
).text
labels = [line.strip() for line in classes_text.splitlines()]
return labels
def top3_possibilities(res):
labels = load_labels()
_, indexes = torch.sort(res, descending=True)
percentage = torch.nn.functional.softmax(res, dim=1)[0]
top3 = dict([(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]])
return top3
##############################################################################
def resnet_inf(numpy_img):
img = preprocess_image(numpy_img)
## Can pass any img or input to the forward module.
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((img.detach().numpy(),))
# print("The top 3 results obtained via shark_runner is:")
return top3_possibilities(torch.from_numpy(result))