# SHARK-Studio/tank/tf/seq_classification.py
#!/usr/bin/env python
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import argparse
# Demo-specific command line flags. parse_known_args() is used (rather than
# parse_args()) so that flags consumed elsewhere, e.g. by shark.parser, are
# tolerated instead of rejected.
seq_parser = argparse.ArgumentParser(description="Shark Sequence Classification.")
seq_parser.add_argument(
    "--hf_model_name",
    type=str,
    default="bert-base-uncased",
    help="Hugging face model to run sequence classification.",
)
# `unknown` holds any flags this parser did not recognize.
seq_args, unknown = seq_parser.parse_known_args()
# Fixed geometry the model is traced/compiled with.
BATCH_SIZE = 1
MAX_SEQUENCE_LENGTH = 16

# Input signature for tf.function tracing: two int32 tensors of shape
# [batch, seq] — (input_ids, attention_mask), in that order.
inputs_signature = [
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
    for _ in range(2)
]
# For supported models please see here:
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification
# Tokenizer cache (model name -> tokenizer). The original code rebuilt the
# tokenizer from disk on every call, which is wasteful in the interactive
# classification loop below; load once per model instead.
_TOKENIZER_CACHE = {}


def preprocess_input(text="This is just used to compile the model"):
    """Tokenize ``text`` for the configured Hugging Face model.

    Args:
        text: Input sentence to tokenize. The default dummy sentence is
            used once at startup to trace/compile the model.

    Returns:
        A transformers ``BatchEncoding`` with TF tensors, padded/truncated
        to ``MAX_SEQUENCE_LENGTH`` (keys include "input_ids" and
        "attention_mask").
    """
    model_name = seq_args.hf_model_name
    tokenizer = _TOKENIZER_CACHE.get(model_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        _TOKENIZER_CACHE[model_name] = tokenizer
    return tokenizer(
        text,
        padding="max_length",
        return_tensors="tf",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    )
class SeqClassification(tf.Module):
    """tf.Module wrapping a Hugging Face sequence-classification model.

    ``forward`` maps (input_ids, attention_mask) — both int32 tensors of
    shape [BATCH_SIZE, MAX_SEQUENCE_LENGTH] — to softmax probabilities
    over the two labels.
    """

    def __init__(self, model_name):
        super().__init__()
        self.m = TFAutoModelForSequenceClassification.from_pretrained(
            model_name, output_attentions=False, num_labels=2
        )

        # Expose a logits-only entry point on the model; indexing with [0]
        # drops everything but the logits from the HF output.
        def _logits(input_ids, attention_mask):
            return self.m(input_ids=input_ids, attention_mask=attention_mask)[0]

        self.m.predict = _logits

    @tf.function(input_signature=inputs_signature)
    def forward(self, input_ids, attention_mask):
        logits = self.m.predict(input_ids, attention_mask)
        return tf.math.softmax(logits, axis=-1)
if __name__ == "__main__":
    # Compile once against a dummy sentence so the traced input shapes
    # are fixed before the interactive loop starts.
    inputs = preprocess_input()
    shark_module = SharkInference(
        SeqClassification(seq_args.hf_model_name),
        (inputs["input_ids"], inputs["attention_mask"]),
    )
    shark_module.set_frontend("tensorflow")
    shark_module.compile()
    print(f"Model has been successfully compiled on {shark_args.device}")

    # Classify user input until an empty line or "q" is entered.
    prompt = "Enter the text to classify (press q or nothing to exit): "
    while (input_text := input(prompt)) not in ("", "q"):
        inputs = preprocess_input(input_text)
        print(
            shark_module.forward(
                (inputs["input_ids"], inputs["attention_mask"])
            )
        )