Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-25 03:00:12 -04:00)
Add the shark_downloader for the torch_models. (#184)
@@ -1,41 +1,24 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from shark.shark_inference import SharkInference
-from shark.shark_importer import SharkImporter
-
-torch.manual_seed(0)
-tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
+from shark.shark_downloader import download_torch_model
 
 
-class MiniLMSequenceClassification(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
-            num_labels=2,  # The number of output labels--2 for binary classification.
-            output_attentions=False,  # Whether the model returns attentions weights.
-            output_hidden_states=False,  # Whether the model returns all hidden-states.
-            torchscript=True,
-        )
-
-    def forward(self, x, y, z):
-        return self.model.forward(x, y, z)[0]
-
-
-test_input = torch.randint(2, (1, 128)).to(torch.int32)
-
-mlir_importer = SharkImporter(
-    MiniLMSequenceClassification(),
-    (test_input, test_input, test_input),
-    frontend="torch",
+mlir_model, func_name, inputs, golden_out = download_torch_model(
+    "microsoft/MiniLM-L12-H384-uncased"
 )
 
-# torch hugging face models needs tracing..
-mlir_importer.import_debug(tracing_required=True)
-
-# print(golden_out)
+shark_module = SharkInference(
+    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
+)
+shark_module.compile()
+result = shark_module.forward(inputs)
+print("The obtained result via shark is: ", result)
+print("The golden result is:", golden_out)
 
-# shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
-# shark_module.compile()
-# result = shark_module.forward((test_input, test_input, test_input))
-# print("Obtained result", result)
+
+# Let's generate random inputs, currently supported
+# for static models.
+rand_inputs = shark_module.generate_random_inputs()
+rand_results = shark_module.forward(rand_inputs)
+
+print("Running shark_module with random_inputs is: ", rand_results)
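
For quick reference, below is a consolidated sketch of the example as this diff leaves it. The code is assembled from the `+` and context lines above; only the comments are added here, as a best-effort reading of the change rather than project documentation, and the `import torch` line is kept from the diff even though the new flow no longer calls torch directly.

import torch
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

# Fetch a pre-imported MLIR artifact for the model, together with its
# entry-point name, sample inputs, and golden (reference) outputs.
# This replaces the old SharkImporter path, which traced the
# HuggingFace model locally (tracing_required=True).
mlir_model, func_name, inputs, golden_out = download_torch_model(
    "microsoft/MiniLM-L12-H384-uncased"
)

# Compile the downloaded linalg-dialect MLIR for CPU and run it on the
# inputs bundled with the artifact, comparing against the golden output.
shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)

# Random-input generation; per the diff's own comment, this is
# currently supported only for static (fixed-shape) models.
rand_inputs = shark_module.generate_random_inputs()
rand_results = shark_module.forward(rand_inputs)
print("Running shark_module with random_inputs is: ", rand_results)

The device="cpu" and mlir_dialect="linalg" arguments are copied verbatim from the diff; whether other device strings are accepted depends on the SHARK build and is not established by this commit.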