Add tf image classification auto model (#213)

This commit is contained in:
Chi_Liu
2022-07-26 23:18:42 -07:00
committed by GitHub
parent dc1a283ab7
commit af4257d05f
4 changed files with 108 additions and 3 deletions

View File

@@ -80,6 +80,7 @@ def save_torch_model(torch_model_list):
def save_tf_model(tf_model_list):
from tank.masked_lm_tf import get_causal_lm_model
from tank.tf.automodelimageclassification import get_causal_image_model
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
@@ -93,6 +94,8 @@ def save_tf_model(tf_model_list):
print(model_type)
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
@@ -184,8 +187,8 @@ if __name__ == "__main__":
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
# if args.tflite_model_csv:
# save_tflite_model(args.tflite_model_csv)
if args.upload:
print("uploading files to gs://shark_tank/")

View File

@@ -197,8 +197,11 @@ class SharkImporter:
golden_out = tuple(
golden_out.numpy(),
)
else:
elif golden_out is tuple:
golden_out = self.convert_to_numpy(golden_out)
else:
# from transformers import TFSequenceClassifierOutput
golden_out = golden_out.logits
# Save the artifacts in the directory dir.
self.save_data(
dir,

View File

@@ -0,0 +1,97 @@
from transformers import TFAutoModelForImageClassification
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
from transformers import BeitFeatureExtractor, AutoFeatureExtractor
import tensorflow as tf
from PIL import Image
import requests
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
# Create a set of input signature.
inputs_signature = [
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
]
class AutoModelImageClassfication(tf.Module):
def __init__(self, model_name):
super(AutoModelImageClassfication, self).__init__()
self.m = TFAutoModelForImageClassification.from_pretrained(
model_name, output_attentions=False
)
self.m.predict = lambda x: self.m(x)
@tf.function(input_signature=inputs_signature)
def forward(self, inputs):
return self.m.predict(inputs)
fail_models = [
"facebook/data2vec-vision-base-ft1k",
"microsoft/swin-tiny-patch4-window7-224",
]
supported_models = [
# "facebook/convnext-tiny-224",
"google/vit-base-patch16-224",
]
img_models_fe_dict = {
"facebook/convnext-tiny-224": ConvNextFeatureExtractor,
"facebook/data2vec-vision-base-ft1k": BeitFeatureExtractor,
"microsoft/swin-tiny-patch4-window7-224": AutoFeatureExtractor,
"google/vit-base-patch16-224": ViTFeatureExtractor,
}
def preprocess_input_image(model_name):
# from datasets import load_dataset
# dataset = load_dataset("huggingface/cats-image")
# image1 = dataset["test"]["image"][0]
# # print("image1: ", image1) # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x7FA0B86BB6D0>
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x7FA0B86BB6D0>
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = img_models_fe_dict[model_name].from_pretrained(
model_name
)
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
inputs = feature_extractor(images=image, return_tensors="tf")
return [inputs[str(*inputs)]]
def get_causal_image_model(hf_name):
model = AutoModelImageClassfication(hf_name)
test_input = preprocess_input_image(hf_name)
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=
# array([[]], dtype=float32)>, hidden_states=None, attentions=None)
actual_out = model.forward(*test_input)
return model, test_input, actual_out
if __name__ == "__main__":
for model_name in supported_models:
print(f"Running model: {model_name}")
inputs = preprocess_input_image(model_name)
model = AutoModelImageClassfication(model_name)
# 1. USE SharkImporter to get the mlir
# from shark.shark_importer import SharkImporter
# mlir_importer = SharkImporter(
# model,
# inputs,
# frontend="tf",
# )
# imported_mlir, func_name = mlir_importer.import_mlir()
# 2. USE SharkDownloader to get the mlir
imported_mlir, func_name, inputs, golden_out = download_tf_model(
model_name
)
shark_module = SharkInference(
imported_mlir, func_name, device="cpu", mlir_dialect="mhlo"
)
shark_module.compile()
shark_module.forward(inputs)

View File

@@ -14,3 +14,5 @@ roberta-base,hf
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,hf
funnel-transformer/small,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
1 model_name model_type
14 xlm-roberta-base hf
15 microsoft/MiniLM-L12-H384-uncased hf
16 funnel-transformer/small hf
17 facebook/convnext-tiny-224 img
18 google/vit-base-patch16-224 img