MiniLM Loader Example

-Add an example that loads MiniLM from SharkHUB and benchmarks it.
-Modify the TF benchmark to use growing (on-demand) GPU memory allocation.
-Add a shark_load helper function.
stanley
2022-06-14 23:34:45 +00:00
parent 95c2e3d6ea
commit 14a56ca9b0
6 changed files with 62 additions and 9 deletions

View File

@@ -2,6 +2,10 @@ import tensorflow as tf
 from transformers import BertModel, BertTokenizer, TFBertModel
 from shark.shark_inference import SharkInference
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
 MAX_SEQUENCE_LENGTH = 512
 BATCH_SIZE = 1
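
Note: by default TensorFlow reserves nearly all GPU memory at startup, which can starve anything else that needs device memory in the same process (here, presumably the IREE runtime behind SharkInference); set_memory_growth switches each device to on-demand allocation. A minimal standalone sketch of the pattern (the try/except guard is an addition in this sketch, not part of the commit):

import tensorflow as tf

# Enable on-demand ("growing") GPU memory allocation for every visible GPU.
# This must run before TensorFlow initializes the devices; afterwards,
# set_memory_growth raises a RuntimeError.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Devices already initialized; the growth setting can no longer change.
        print(e)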

View File

@@ -0,0 +1,41 @@
+import tensorflow as tf
+from transformers import BertModel, BertTokenizer, TFBertModel
+from shark.shark_inference import SharkInference
+from shark.shark_importer import shark_load
+from shark.parser import parser
+import os
+
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
+
+parser.add_argument(
+    "--download_mlir_path",
+    type=str,
+    default="minilm_tf_inference.mlir",
+    help="Specifies path to target mlir file that will be loaded.")
+load_args, unknown = parser.parse_known_args()
+
+MAX_SEQUENCE_LENGTH = 512
+
+if __name__ == "__main__":
+    # Prepping Data
+    tokenizer = BertTokenizer.from_pretrained(
+        "microsoft/MiniLM-L12-H384-uncased")
+    text = "Replace me by any text you'd like."
+    encoded_input = tokenizer(text,
+                              padding='max_length',
+                              truncation=True,
+                              max_length=MAX_SEQUENCE_LENGTH)
+    for key in encoded_input:
+        encoded_input[key] = tf.expand_dims(
+            tf.convert_to_tensor(encoded_input[key]), 0)
+    model_name = "minilm_tf_inference"
+    minilm_mlir = shark_load(model_name, load_args.download_mlir_path)
+    test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
+                  encoded_input["token_type_ids"])
+    shark_module = SharkInference(
+        minilm_mlir, test_input, benchmark_mode=True)
+    shark_module.set_frontend("mhlo")
+    shark_module.compile()
+    shark_module.benchmark_all(test_input)
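
For reference, the same compiled module can run a single inference instead of the full benchmark loop; a minimal sketch reusing minilm_mlir and test_input from above (assumes benchmark_mode defaults to False when omitted, and that SharkInference exposes forward(), consistent with the shark_runner.forward call in the last diff below):

shark_module = SharkInference(minilm_mlir, test_input)
shark_module.set_frontend("mhlo")
shark_module.compile()
# Single forward pass through the compiled MiniLM module.
output = shark_module.forward(test_input)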

View File

@@ -5,7 +5,7 @@ import tensorflow as tf
 from shark.shark_trainer import SharkTrainer
 from shark.parser import parser
-from urllib import request
+from shark.shark_importer import shark_load
 parser.add_argument(
     "--download_mlir_path",
@@ -27,14 +27,10 @@ if __name__ == "__main__":
         np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
         np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
     ]
-    file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir"
-    response = request.urlretrieve(file_link, load_args.download_mlir_path)
+    model_name = "bert_tf_training"
+    bert_mlir = shark_load(model_name, load_args.download_mlir_path)
     sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
     num_iter = 10
-    if not os.path.isfile(load_args.download_mlir_path):
-        raise ValueError(f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found.")
-    with open(load_args.download_mlir_path, "rb") as input_file:
-        bert_mlir = input_file.read()
     shark_module = SharkTrainer(
         bert_mlir,
         (sample_input_tensors,

View File

@@ -287,7 +287,7 @@ def tensor_to_type_str(input_tensors: tuple, frontend: str):
         type_string = "x".join([str(dim) for dim in input_tensor.shape])
         if frontend in ["torch", "pytorch"]:
             dtype_string = str(input_tensor.dtype).replace("torch.", "")
-        elif frontend in ["tensorflow", "tf"]:
+        elif frontend in ["tensorflow", "tf", "mhlo"]:
             dtype = input_tensor.dtype
             dtype_string = re.findall('\'[^"]*\'',
                                       str(dtype))[0].replace("\'", "")
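
The only functional change here is routing the new "mhlo" frontend through the existing TensorFlow branch. A small sketch of what that branch computes for a (1, 512) int32 tensor, mirroring the two lines above (assumes TF 2.x DType string formatting):

import re
import tensorflow as tf

input_tensor = tf.zeros((1, 512), dtype=tf.int32)
type_string = "x".join([str(dim) for dim in input_tensor.shape])  # "1x512"
# str(tf.int32) renders as "<dtype: 'int32'>"; the regex extracts int32.
dtype_string = re.findall('\'[^"]*\'',
                          str(input_tensor.dtype))[0].replace("\'", "")
print(type_string, dtype_string)  # 1x512 int32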

View File

@@ -122,3 +122,15 @@ class SharkImporter:
             print("Inference", self.model_source_hub, " not implemented yet")
         elif self.model_source_hub == "jaxhub":
             print("Inference", self.model_source_hub, " not implemented yet")
+
+
+def shark_load(model_name, file_path):
+    file_link = f"https://storage.googleapis.com/shark_tank/users/stanley/{model_name}.mlir"
+    response = urllib.request.urlretrieve(file_link, file_path)
+    if not os.path.isfile(file_path):
+        raise ValueError(
+            f"Tried looking for target mlir in {file_path}, but it cannot be found."
+        )
+    with open(file_path, "rb") as input_file:
+        model_mlir = input_file.read()
+    return model_mlir
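
Usage of the new helper mirrors the two examples above: pass the model name as it appears under the shark_tank bucket plus a local cache path, and feed the returned bytes to SharkInference or SharkTrainer. Note the download URL is hard-coded to the users/stanley prefix, so only models staged there will resolve. A one-line sketch:

# Fetches .../users/stanley/minilm_tf_inference.mlir into the local file
# and returns its raw bytes.
minilm_mlir = shark_load("minilm_tf_inference", "minilm_tf_inference.mlir")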

View File

@@ -84,7 +84,7 @@ class SharkInference:
         return self.shark_runner.forward(input_list, self.frontend)

     # Saves the .vmfb module.
-    def save_module(self, dir = None):
+    def save_module(self, dir=None):
         if dir is None:
             return self.shark_runner.save_module()
         return self.shark_runner.save_module(dir)
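
A usage sketch for the cleaned-up signature (the directory path is illustrative; per the code above, save_module simply forwards to the runner with or without a target):

shark_module.save_module()        # let the runner pick its default location
shark_module.save_module("/tmp")  # hypothetical output directory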