Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00.
Add tf image classification auto model (#213)
This commit is contained in:
@@ -80,6 +80,7 @@ def save_torch_model(torch_model_list):
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.tf.automodelimageclassification import get_causal_image_model
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -93,6 +94,8 @@ def save_tf_model(tf_model_list):
|
||||
print(model_type)
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
@@ -184,8 +187,8 @@ if __name__ == "__main__":
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
# if args.tflite_model_csv:
|
||||
# save_tflite_model(args.tflite_model_csv)
|
||||
|
||||
if args.upload:
|
||||
print("uploading files to gs://shark_tank/")
|
||||
|
||||
@@ -197,8 +197,11 @@ class SharkImporter:
|
||||
golden_out = tuple(
|
||||
golden_out.numpy(),
|
||||
)
|
||||
else:
|
||||
elif golden_out is tuple:
|
||||
golden_out = self.convert_to_numpy(golden_out)
|
||||
else:
|
||||
# from transformers import TFSequenceClassifierOutput
|
||||
golden_out = golden_out.logits
|
||||
# Save the artifacts in the directory dir.
|
||||
self.save_data(
|
||||
dir,
|
||||
|
||||
97
tank/tf/automodelimageclassification.py
Normal file
97
tank/tf/automodelimageclassification.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from transformers import TFAutoModelForImageClassification
|
||||
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
|
||||
from transformers import BeitFeatureExtractor, AutoFeatureExtractor
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
import requests
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
# Input signature for tf.function tracing: one channels-first image tensor
# (batch=1, channels=3, height=224, width=224), float32.
inputs_signature = [
    tf.TensorSpec([1, 3, 224, 224], tf.float32),
]
|
||||
|
||||
|
||||
class AutoModelImageClassfication(tf.Module):
    """Wrap a HuggingFace TF image-classification model as a tf.Module.

    NOTE(review): the class name misspells "Classification"; kept as-is
    because other code references the class by this exact name.
    """

    def __init__(self, model_name):
        super().__init__()
        self.m = TFAutoModelForImageClassification.from_pretrained(
            model_name,
            output_attentions=False,
        )
        # Monkey-patch `predict` to a direct call on the model —
        # presumably so the traced `forward` below stays a plain graph
        # call rather than Keras' predict loop (TODO confirm).
        self.m.predict = lambda pixel_values: self.m(pixel_values)

    @tf.function(input_signature=inputs_signature)
    def forward(self, inputs):
        """Run classification on a (1, 3, 224, 224) float32 image batch."""
        return self.m.predict(inputs)
|
||||
|
||||
|
||||
# Models that do not work through this path (unused below; presumably kept
# as a record of known failures — confirm against CI).
fail_models = [
    "facebook/data2vec-vision-base-ft1k",
    "microsoft/swin-tiny-patch4-window7-224",
]

# Models exercised by the __main__ smoke test at the bottom of this file.
supported_models = [
    # "facebook/convnext-tiny-224",
    "google/vit-base-patch16-224",
]

# Maps each HF model name to the feature-extractor class that prepares its
# input tensor (see preprocess_input_image).
img_models_fe_dict = {
    "facebook/convnext-tiny-224": ConvNextFeatureExtractor,
    "facebook/data2vec-vision-base-ft1k": BeitFeatureExtractor,
    "microsoft/swin-tiny-patch4-window7-224": AutoFeatureExtractor,
    "google/vit-base-patch16-224": ViTFeatureExtractor,
}
|
||||
|
||||
|
||||
def preprocess_input_image(model_name):
    """Fetch a test image and run *model_name*'s feature extractor on it.

    Returns a single-element list holding the extractor's output tensor,
    so callers can splat it directly into ``model.forward(*inputs)``.
    Per this file's input signature the tensor is (1, 3, 224, 224) float32.
    """
    # A COCO val2017 image commonly used as a classification smoke-test input.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    feature_extractor = img_models_fe_dict[model_name].from_pretrained(
        model_name
    )
    # e.g. {'pixel_values': <tf.Tensor shape=(1, 3, 224, 224) float32>}
    inputs = feature_extractor(images=image, return_tensors="tf")
    # Fix: the original `inputs[str(*inputs)]` splatted the mapping's keys
    # into str(), which only works when there is exactly one key AND the
    # key stringifies to itself. Take the first key explicitly instead.
    first_key = next(iter(inputs))
    return [inputs[first_key]]
|
||||
|
||||
|
||||
def get_causal_image_model(hf_name):
    """Build the wrapped TF model plus a sample input and golden output.

    Returns a (model, test_input, actual_out) triple, where actual_out is
    the model's output for the sample image — e.g. a
    TFSequenceClassifierOutput whose `logits` is a (1, 1000) float32 tensor.
    """
    wrapped = AutoModelImageClassfication(hf_name)
    sample_input = preprocess_input_image(hf_name)
    golden_output = wrapped.forward(*sample_input)
    return wrapped, sample_input, golden_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: for each supported model, fetch its pre-imported MLIR
    # from the shark tank, compile it for CPU, and run one inference.
    for model_name in supported_models:
        print(f"Running model: {model_name}")
        inputs = preprocess_input_image(model_name)
        model = AutoModelImageClassfication(model_name)

        # Alternative path (not used here): import the MLIR locally via
        # shark.shark_importer.SharkImporter(model, inputs, frontend="tf").
        #
        # NOTE(review): `inputs` is rebound by the download below, so the
        # locally preprocessed image above is not what gets executed.
        imported_mlir, func_name, inputs, golden_out = download_tf_model(
            model_name
        )

        shark_module = SharkInference(
            imported_mlir, func_name, device="cpu", mlir_dialect="mhlo"
        )
        shark_module.compile()
        shark_module.forward(inputs)
|
||||
@@ -14,3 +14,5 @@ roberta-base,hf
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,hf
|
||||
funnel-transformer/small,hf
|
||||
facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
|
||||
|
Reference in New Issue
Block a user