Add support for TensorFlow modules + TF miniLM example (#51)

Ean Garvey
2022-05-25 20:50:43 -05:00
committed by GitHub
parent 737be5be09
commit 772f60c313
5 changed files with 343 additions and 1 deletion


@@ -0,0 +1,71 @@
import time

import numpy as np
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

from shark.shark_runner_tf import SharkTrainerTF

vocab_size = 100
NUM_CLASSES = 5
SEQUENCE_LENGTH = 512
BATCH_SIZE = 1

# The classifier takes three 2-D int32 inputs (word ids, input mask, and
# segment ids), each of shape [BATCH_SIZE, SEQUENCE_LENGTH].
bert_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
]


class BertModule(tf.Module):

    def __init__(self):
        super(BertModule, self).__init__()
        dict_outputs = False
        test_network = networks.BertEncoder(
            vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)

        # Create a BERT trainer with the created network.
        bert_trainer_model = bert_classifier.BertClassifier(
            test_network, num_classes=NUM_CLASSES)
        bert_trainer_model.summary()

        self.m = bert_trainer_model
        # Expose inference as a tf.function with a fixed input signature.
        self.m.predict = lambda x: self.m.call(x, training=False)
        self.predict = tf.function(
            input_signature=[bert_input])(self.m.predict)
        self.m.learn = lambda x, y: self.m.call(x, training=True)
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

    @tf.function(input_signature=[
        bert_input,  # inputs
        tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32),  # labels
    ])
    def forward(self, inputs, labels):
        with tf.GradientTape() as tape:
            # Capture the gradients from forward prop...
            probs = self.m(inputs, training=True)
            loss = self.loss(labels, probs)
        # ...and use them to update the model's weights.
        variables = self.m.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return probs


if __name__ == "__main__":
    predict_sample_input = np.asarray([
        np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
        np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
        np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
    ])
    sample_input_tensors = tf.convert_to_tensor(predict_sample_input)
    num_iter = 10
    shark_module = SharkTrainerTF(
        BertModule(),
        (sample_input_tensors,
         tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE,)))))
    start = time.time()
    print(shark_module.train(num_iter))
    end = time.time()
    total_time = end - start
    print("time: " + str(total_time))
    print("time/iter: " + str(total_time / num_iter))


@@ -0,0 +1,38 @@
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

from shark.shark_runner_tf import SharkInferenceTF

MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1

# The model takes three 2-D int32 inputs (input ids, attention mask, and
# token type ids), each of shape [BATCH_SIZE, MAX_SEQUENCE_LENGTH].
bert_input = [
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]


class BertModule(tf.Module):

    def __init__(self):
        super(BertModule, self).__init__()
        # Load MiniLM from its PyTorch checkpoint into a TF model.
        self.m = TFBertModel.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased", from_pt=True)
        self.m.predict = lambda x, y, z: self.m.call(
            input_ids=x, attention_mask=y, token_type_ids=z, training=False)

    @tf.function(input_signature=bert_input)
    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.m.predict(input_ids, attention_mask, token_type_ids)


if __name__ == "__main__":
    # Prepare the data: tokenize and pad to the fixed sequence length.
    tokenizer = BertTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased")
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text,
                              padding="max_length",
                              truncation=True,
                              max_length=MAX_SEQUENCE_LENGTH)
    # Add the batch dimension expected by the input signature.
    for key in encoded_input:
        encoded_input[key] = tf.expand_dims(
            tf.convert_to_tensor(encoded_input[key]), 0)

    shark_module = SharkInferenceTF(
        BertModule(),
        (encoded_input["input_ids"], encoded_input["attention_mask"],
         encoded_input["token_type_ids"]))
    print(
        shark_module.forward((encoded_input["input_ids"],
                              encoded_input["attention_mask"],
                              encoded_input["token_type_ids"])))
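
As a quick cross-check (again not part of the diff), the same tokenized inputs can be run through the module in stock TensorFlow and compared against SHARK's output. A minimal sketch, assuming the script above has already built encoded_input:

# Hypothetical reference run; not part of this commit.
golden = BertModule().forward(encoded_input["input_ids"],
                              encoded_input["attention_mask"],
                              encoded_input["token_type_ids"])
# golden[0] is the last hidden state; MiniLM-L12-H384 has hidden size 384.
print(golden[0].shape)  # (BATCH_SIZE, MAX_SEQUENCE_LENGTH, 384)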