From fe080eaee634b964ac596221c5ccba29023cd5f1 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Thu, 1 Sep 2022 11:47:52 +0530 Subject: [PATCH] [WEB] Introduce web interface for the SHARK models (#298) This commit introduces web application for SHARK using gradio platform. This adds web visualization of `Resnet50` and `Albert_Maskfill` models as a start. Signed-Off-by: Gaurav Shukla Signed-off-by: Gaurav Shukla --- shark_web/index.py | 20 +++++++++ shark_web/models/__init__.py | 0 shark_web/models/albert_maskfill.py | 69 +++++++++++++++++++++++++++++ shark_web/models/resnet50.py | 54 ++++++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 shark_web/index.py create mode 100644 shark_web/models/__init__.py create mode 100644 shark_web/models/albert_maskfill.py create mode 100644 shark_web/models/resnet50.py diff --git a/shark_web/index.py b/shark_web/index.py new file mode 100644 index 00000000..5230bf12 --- /dev/null +++ b/shark_web/index.py @@ -0,0 +1,20 @@ +from models.resnet50 import resnet_inf +from models.albert_maskfill import albert_maskfill_inf +import gradio as gr + +shark_web = gr.Blocks() + +with shark_web: + image = gr.Image() + label1 = gr.Label() + resnet = gr.Button("Recognize Image") + + text = gr.Textbox() + label2 = gr.Label() + albert_mask = gr.Button("Decode Mask") + + resnet.click(resnet_inf, inputs=image, outputs=label1) + albert_mask.click(albert_maskfill_inf, inputs=text, outputs=label2) + +shark_web.launch(share=True) + diff --git a/shark_web/models/__init__.py b/shark_web/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/shark_web/models/albert_maskfill.py b/shark_web/models/albert_maskfill.py new file mode 100644 index 00000000..e39005ea --- /dev/null +++ b/shark_web/models/albert_maskfill.py @@ -0,0 +1,69 @@ +from transformers import AutoModelForMaskedLM, AutoTokenizer +import torch +from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter +import numpy as np + +MAX_SEQUENCE_LENGTH = 512 +BATCH_SIZE = 1 + +class AlbertModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2") + self.model.eval() + + def forward(self, input_ids, attention_mask): + return self.model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + +################################## Preprocessing inputs and model ############ + +tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") + +def preprocess_data(text): + # Preparing Data + encoded_inputs = tokenizer( + text, + padding="max_length", + truncation=True, + max_length=MAX_SEQUENCE_LENGTH, + return_tensors="pt", + ) + inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"]) + return inputs + +def top5_possibilities(text, inputs, token_logits): + mask_id = torch.where( + inputs[0] == tokenizer.mask_token_id + )[1] + mask_token_logits = token_logits[0, mask_id, :] + percentage = torch.nn.functional.softmax(mask_token_logits, dim=1)[0] + top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() + top5 = {} + for token in top_5_tokens: + label = text.replace(tokenizer.mask_token, tokenizer.decode(token)) + top5[label] = percentage[token].item() + return top5 + +############################################################################## + +def albert_maskfill_inf(masked_text): + inputs = preprocess_data(masked_text) + mlir_importer = SharkImporter( + AlbertModule(), + inputs, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=False, tracing_required=True + ) + shark_module = SharkInference( + minilm_mlir, func_name, mlir_dialect="linalg" + ) + shark_module.compile() + token_logits = torch.tensor(shark_module.forward(inputs)) + return top5_possibilities(masked_text, inputs, token_logits) + diff --git a/shark_web/models/resnet50.py b/shark_web/models/resnet50.py new file mode 100644 index 00000000..9166287f --- /dev/null +++ b/shark_web/models/resnet50.py @@ -0,0 +1,54 @@ +from PIL import Image +import requests +import torch +from torchvision import transforms +from shark.shark_inference import SharkInference +from shark.shark_downloader import download_torch_model + +################################## Preprocessing inputs and model ############ + +def preprocess_image(img): + image = Image.fromarray(img) + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + img_preprocessed = preprocess(image) + return torch.unsqueeze(img_preprocessed, 0) + + +def load_labels(): + classes_text = requests.get( + "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", + stream=True, + ).text + labels = [line.strip() for line in classes_text.splitlines()] + return labels + + +def top3_possibilities(res): + labels = load_labels() + _, indexes = torch.sort(res, descending=True) + percentage = torch.nn.functional.softmax(res, dim=1)[0] + top3 = dict([(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]) + return top3 + +############################################################################## + +def resnet_inf(numpy_img): + img = preprocess_image(numpy_img) + ## Can pass any img or input to the forward module. + mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50") + + shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg") + shark_module.compile() + result = shark_module.forward((img.detach().numpy(),)) + + # print("The top 3 results obtained via shark_runner is:") + return top3_possibilities(torch.from_numpy(result))