mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
Mini LM Loader Example
-Add example to load miniLM from SharkHUB and benchmark. -Modify TF benchmark to have growing GPU allocation. -Add shark_load helper function
This commit is contained in:
@@ -2,6 +2,10 @@ import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
41
shark/examples/shark_inference/minilm_load_benchmark_tf.py
Normal file
41
shark/examples/shark_inference/minilm_load_benchmark_tf.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import shark_load
|
||||
from shark.parser import parser
|
||||
import os
|
||||
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
parser.add_argument(
|
||||
"--download_mlir_path",
|
||||
type=str,
|
||||
default="minilm_tf_inference.mlir",
|
||||
help="Specifies path to target mlir file that will be loaded.")
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased")
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0)
|
||||
model_name = "minilm_tf_inference"
|
||||
minilm_mlir = shark_load(model_name, load_args.download_mlir_path)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"])
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, test_input, benchmark_mode=True)
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
@@ -5,7 +5,7 @@ import tensorflow as tf
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from shark.parser import parser
|
||||
from urllib import request
|
||||
from shark.shark_importer import shark_load
|
||||
|
||||
parser.add_argument(
|
||||
"--download_mlir_path",
|
||||
@@ -27,14 +27,10 @@ if __name__ == "__main__":
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
]
|
||||
file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir"
|
||||
response = request.urlretrieve(file_link, load_args.download_mlir_path)
|
||||
model_name = "bert_tf_training"
|
||||
bert_mlir = shark_load(model_name, load_args.download_mlir_path)
|
||||
sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
|
||||
num_iter = 10
|
||||
if not os.path.isfile(load_args.download_mlir_path):
|
||||
raise ValueError(f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found.")
|
||||
with open(load_args.download_mlir_path, "rb") as input_file:
|
||||
bert_mlir = input_file.read()
|
||||
shark_module = SharkTrainer(
|
||||
bert_mlir,
|
||||
(sample_input_tensors,
|
||||
|
||||
@@ -287,7 +287,7 @@ def tensor_to_type_str(input_tensors: tuple, frontend: str):
|
||||
type_string = "x".join([str(dim) for dim in input_tensor.shape])
|
||||
if frontend in ["torch", "pytorch"]:
|
||||
dtype_string = str(input_tensor.dtype).replace("torch.", "")
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
dtype = input_tensor.dtype
|
||||
dtype_string = re.findall('\'[^"]*\'',
|
||||
str(dtype))[0].replace("\'", "")
|
||||
|
||||
@@ -122,3 +122,15 @@ class SharkImporter:
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
|
||||
|
||||
def shark_load(model_name, file_path):
|
||||
file_link = f"https://storage.googleapis.com/shark_tank/users/stanley/{model_name}.mlir"
|
||||
response = urllib.request.urlretrieve(file_link, file_path)
|
||||
if not os.path.isfile(file_path):
|
||||
raise ValueError(
|
||||
f"Tried looking for target mlir in {file_path}, but cannot be found."
|
||||
)
|
||||
with open(file_path, "rb") as input_file:
|
||||
model_mlir = input_file.read()
|
||||
return model_mlir
|
||||
|
||||
@@ -84,7 +84,7 @@ class SharkInference:
|
||||
return self.shark_runner.forward(input_list, self.frontend)
|
||||
|
||||
# Saves the .vmfb module.
|
||||
def save_module(self, dir = None):
|
||||
def save_module(self, dir=None):
|
||||
if dir is None:
|
||||
return self.shark_runner.save_module()
|
||||
return self.shark_runner.save_module(dir)
|
||||
|
||||
Reference in New Issue
Block a user