mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
[WEB] Introduce web interface for the SHARK models (#298)
This commit introduces a web application for SHARK using the Gradio platform. It adds web visualization of the `Resnet50` and `Albert_Maskfill` models as a start. Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
20
shark_web/index.py
Normal file
20
shark_web/index.py
Normal file
@@ -0,0 +1,20 @@
"""Gradio web front end for the SHARK model demos.

Wires two demo models into a single Blocks page:
  * Resnet50 image classification (``resnet_inf``)
  * ALBERT mask-fill text completion (``albert_maskfill_inf``)
"""
from models.resnet50 import resnet_inf
from models.albert_maskfill import albert_maskfill_inf
import gradio as gr

shark_web = gr.Blocks()

with shark_web:
    # Image-classification demo: image in, class-probability label out.
    image = gr.Image()
    label1 = gr.Label()
    resnet = gr.Button("Recognize Image")

    # Mask-fill demo: masked sentence in, top-5 completions out.
    text = gr.Textbox()
    label2 = gr.Label()
    albert_mask = gr.Button("Decode Mask")

    resnet.click(resnet_inf, inputs=image, outputs=label1)
    albert_mask.click(albert_maskfill_inf, inputs=text, outputs=label2)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. in
    # tests) does not start a server as a side effect.
    # NOTE(review): share=True exposes a public tunnel URL — confirm this is
    # intended outside local development.
    shark_web.launch(share=True)
0
shark_web/models/__init__.py
Normal file
0
shark_web/models/__init__.py
Normal file
69
shark_web/models/albert_maskfill.py
Normal file
69
shark_web/models/albert_maskfill.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
import numpy as np
|
||||
|
||||
# Fixed token length every ALBERT input is padded/truncated to, so the
# compiled model always sees the same input shape.
MAX_SEQUENCE_LENGTH = 512
# NOTE(review): BATCH_SIZE is defined but not referenced in this file —
# inference below is single-example; confirm whether it can be removed.
BATCH_SIZE = 1
|
||||
|
||||
class AlbertModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).logits
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
|
||||
def preprocess_data(text):
    """Tokenize *text* into ``(input_ids, attention_mask)`` tensors.

    The sequence is padded/truncated to MAX_SEQUENCE_LENGTH so the compiled
    model always receives a fixed input shape.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]
|
||||
|
||||
def top5_possibilities(text, inputs, token_logits):
    """Return {completed sentence: probability} for the 5 likeliest fillers.

    NOTE(review): the ``[0]`` indexing below only looks at the first [MASK]
    position — this assumes *text* contains exactly one mask token; confirm
    callers guarantee that.
    """
    input_ids = inputs[0]
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1]
    mask_logits = token_logits[0, mask_positions, :]
    probabilities = torch.nn.functional.softmax(mask_logits, dim=1)[0]
    best_tokens = torch.topk(mask_logits, 5, dim=1).indices[0].tolist()
    # Substitute each candidate token into the original sentence.
    return {
        text.replace(tokenizer.mask_token, tokenizer.decode(token)):
            probabilities[token].item()
        for token in best_tokens
    }
|
||||
|
||||
##############################################################################
|
||||
|
||||
def _compile_albert_module(inputs):
    """Import AlbertModule to MLIR and compile it with SHARK (expensive)."""
    mlir_importer = SharkImporter(
        AlbertModule(),
        inputs,
        frontend="torch",
    )
    mlir_module, func_name = mlir_importer.import_mlir(
        is_dynamic=False, tracing_required=True
    )
    shark_module = SharkInference(
        mlir_module, func_name, mlir_dialect="linalg"
    )
    shark_module.compile()
    return shark_module


def albert_maskfill_inf(masked_text):
    """Run SHARK-compiled ALBERT mask-fill on *masked_text*.

    Returns a dict mapping each candidate completed sentence to its
    probability (see ``top5_possibilities``).
    """
    inputs = preprocess_data(masked_text)
    # Compile once and reuse: the original recompiled the model on every
    # request, which dominates latency. Inputs are always padded to the same
    # fixed shape, so one compiled module serves all requests.
    module = getattr(albert_maskfill_inf, "_shark_module", None)
    if module is None:
        module = _compile_albert_module(inputs)
        albert_maskfill_inf._shark_module = module
    token_logits = torch.tensor(module.forward(inputs))
    return top5_possibilities(masked_text, inputs, token_logits)
|
||||
|
||||
54
shark_web/models/resnet50.py
Normal file
54
shark_web/models/resnet50.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
|
||||
def preprocess_image(img):
    """Convert a numpy HWC image to a normalized ``1x3x224x224`` tensor.

    Applies standard ImageNet preprocessing: resize, center crop, tensor
    conversion, and mean/std normalization.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    tensor = pipeline(Image.fromarray(img))
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
|
||||
|
||||
|
||||
def load_labels():
    """Fetch the ImageNet class labels, one label per line.

    The list is downloaded once and memoized on the function object so
    repeated inferences do not re-fetch it over the network; a timeout is
    set so a slow fetch cannot hang the web request indefinitely.
    """
    cached = getattr(load_labels, "_labels", None)
    if cached is not None:
        return cached
    classes_text = requests.get(
        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
        stream=True,
        timeout=10,  # fail fast instead of blocking the UI forever
    ).text
    labels = [line.strip() for line in classes_text.splitlines()]
    load_labels._labels = labels
    return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
    """Map the 3 highest-scoring ImageNet labels to softmax probabilities.

    *res* is the raw logits tensor; the ``[0]`` indexing assumes a batch of
    one, matching what ``resnet_inf`` passes in.
    """
    labels = load_labels()
    probabilities = torch.nn.functional.softmax(res, dim=1)[0]
    _, ranked = torch.sort(res, descending=True)
    return {
        labels[idx]: probabilities[idx].item()
        for idx in ranked[0][:3]
    }
|
||||
|
||||
##############################################################################
|
||||
|
||||
def resnet_inf(numpy_img):
    """Classify *numpy_img* with a SHARK-compiled Resnet50.

    Returns a dict mapping the top-3 ImageNet labels to their probabilities.
    """
    img = preprocess_image(numpy_img)
    # Download and compile once, then reuse: the original re-downloaded the
    # MLIR model and recompiled it on every call, which dominates
    # per-request latency.
    module = getattr(resnet_inf, "_shark_module", None)
    if module is None:
        mlir_model, func_name, _inputs, _golden_out = download_torch_model(
            "resnet50"
        )
        module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
        module.compile()
        resnet_inf._shark_module = module
    result = module.forward((img.detach().numpy(),))
    return top3_possibilities(torch.from_numpy(result))
|
||||
Reference in New Issue
Block a user