mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 22:07:55 -05:00
36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
import torch
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
from amdshark.amdshark_inference import AMDSharkInference
|
|
|
|
# Seed torch's global RNG so the randomly generated test input further down
# in this script is the same on every run.
torch.manual_seed(0)

# Tokenizer for the same MiniLM checkpoint the model wrapper loads.
# NOTE(review): nothing in the visible part of this script uses `tokenizer`;
# it is kept because other code may import it -- confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/MiniLM-L12-H384-uncased"
)
|
|
|
|
|
|
class MiniLMSequenceClassification(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around the pretrained MiniLM classifier.

    The wrapper loads ``microsoft/MiniLM-L12-H384-uncased`` with a two-label
    sequence-classification head and exposes a ``forward`` that takes a single
    tensor and returns a single value, which keeps the module easy to trace.
    """

    def __init__(self):
        """Load the pretrained checkpoint (may download it on first use)."""
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attentions weights.
            output_hidden_states=False,  # Whether the model returns all hidden-states.
            # Ask transformers for a TorchScript/trace-friendly configuration
            # (see the Hugging Face TorchScript export documentation).
            torchscript=True,
        )

    def forward(self, tokens):
        """Run the wrapped model on ``tokens`` and return its first output.

        Args:
            tokens: Input passed positionally to the wrapped model.
                Presumably a tensor of token ids -- confirm against callers.

        Returns:
            Element ``0`` of the wrapped model's output (presumably the
            classification logits -- TODO confirm).
        """
        # Call the module instance rather than ``.forward`` directly so that
        # any hooks registered on the wrapped model are executed.
        return self.model(tokens)[0]
|
|
|
|
|
|
# Example input: a (1, 128) integer tensor whose entries are drawn from
# {0, 1} (``high`` is exclusive).
test_input = torch.randint(low=0, high=2, size=(1, 128))

# Hand the wrapped model and a tuple of example inputs to AMDSHARK.
# NOTE(review): the exact meaning of `jit_trace` and `benchmark_mode` is
# defined by AMDSharkInference, not visible here -- confirm in its docs.
amdshark_module = AMDSharkInference(
    MiniLMSequenceClassification(),
    (test_input,),
    jit_trace=True,
    benchmark_mode=True,
)

# Compile, run once on the example input (the result is discarded), then
# run the benchmark entry point on the same input.
amdshark_module.compile()
amdshark_module.forward((test_input,))
amdshark_module.benchmark_all((test_input,))
|