Add the shark_downloader for the torch_models. (#184)

This commit is contained in:
Prashant Kumar
2022-07-15 02:11:43 +05:30
committed by GitHub
parent 8434c67d96
commit 1191f53c9d
2 changed files with 40 additions and 34 deletions

View File

@@ -1,41 +1,24 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
from shark.shark_downloader import download_torch_model
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, x, y, z):
return self.model.forward(x, y, z)[0]
test_input = torch.randint(2, (1, 128)).to(torch.int32)
mlir_importer = SharkImporter(
MiniLMSequenceClassification(),
(test_input, test_input, test_input),
frontend="torch",
mlir_model, func_name, inputs, golden_out = download_torch_model(
"microsoft/MiniLM-L12-H384-uncased"
)
# torch hugging face models needs tracing..
mlir_importer.import_debug(tracing_required=True)
# print(golden_out)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)
# shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
# shark_module.compile()
# result = shark_module.forward((test_input, test_input, test_input))
# print("Obtained result", result)
# Let's generate random inputs, currently supported
# for static models.
rand_inputs = shark_module.generate_random_inputs()
rand_results = shark_module.forward(rand_inputs)
print("Running shark_module with random_inputs is: ", rand_results)