mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
generate shark tank for tflite (#173)
* Add gen_shark_tank support tflite * gen_shark_tank.py use SharkImporter to save model
This commit is contained in:
@@ -15,8 +15,7 @@ import os
|
||||
import urllib.request
|
||||
import csv
|
||||
import argparse
|
||||
import iree.compiler.tflite as ireec_tflite
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
class SharkTank:
|
||||
@@ -32,77 +31,78 @@ class SharkTank:
|
||||
self.tflite_model_list = tflite_model_list
|
||||
self.upload = upload
|
||||
|
||||
print("Setting up for TMP_DIR")
|
||||
self.workdir = os.path.join(os.path.dirname(__file__), "./gen_shark_tank")
|
||||
print(f"tflite TMP_shark_tank_DIR = {self.workdir}")
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
|
||||
if self.torch_model_list is not None:
|
||||
print("Process torch model")
|
||||
else:
|
||||
print("Torch sharktank not implemented yet")
|
||||
self.save_torch_model()
|
||||
|
||||
if self.tf_model_list is not None:
|
||||
print("Process torch model")
|
||||
else:
|
||||
print("tf sharktank not implemented yet")
|
||||
self.save_tf_model()
|
||||
|
||||
print("self.tflite_model_list: ", self.tflite_model_list)
|
||||
# compile and run tfhub tflite
|
||||
if self.tflite_model_list is not None:
|
||||
print("Setting up for tflite TMP_DIR")
|
||||
self.tflite_workdir = os.path.join(os.path.dirname(__file__), "./gen_shark_tank")
|
||||
print(f"tflite TMP_shark_tank_DIR = {self.tflite_workdir}")
|
||||
os.makedirs(self.tflite_workdir, exist_ok=True)
|
||||
self.save_tflite_model()
|
||||
|
||||
with open(self.tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(self.tflite_workdir, str(tflite_model_name))
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
|
||||
tflite_saving_file = "/".join(
|
||||
[
|
||||
tflite_model_name_dir,
|
||||
str(tflite_model_name) + "_tflite.tflite",
|
||||
]
|
||||
)
|
||||
tflite_tosa_file = "/".join(
|
||||
[
|
||||
tflite_model_name_dir,
|
||||
str(tflite_model_name) + "_tflite.mlir",
|
||||
]
|
||||
)
|
||||
self.binary = "/".join(
|
||||
[
|
||||
tflite_model_name_dir,
|
||||
str(tflite_model_name) + "_module.bytecode",
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Setting up local address for tflite model file: ",
|
||||
tflite_saving_file,
|
||||
)
|
||||
if os.path.exists(tflite_saving_file):
|
||||
print(tflite_saving_file, "exists")
|
||||
else:
|
||||
print("Download tflite model")
|
||||
urllib.request.urlretrieve(str(tflite_model_link), tflite_saving_file)
|
||||
|
||||
if os.path.exists(tflite_tosa_file):
|
||||
print("Exists", tflite_tosa_file)
|
||||
else:
|
||||
print("Convert tflite to tosa.mlir")
|
||||
ireec_tflite.compile_file(
|
||||
tflite_saving_file,
|
||||
input_type="tosa",
|
||||
save_temp_iree_input=tflite_tosa_file,
|
||||
target_backends=[IREE_TARGET_MAP["cpu"]],
|
||||
import_only=False,
|
||||
)
|
||||
|
||||
if self.upload == True:
|
||||
if self.upload:
|
||||
print("upload tmp tank to gcp")
|
||||
os.system("gsutil cp -r ./gen_shark_tank gs://shark_tank/")
|
||||
|
||||
def save_torch_model(self):
|
||||
from tank.model_utils import get_hf_model
|
||||
|
||||
print("Torch sharktank not implemented yet")
|
||||
|
||||
def save_tf_model(self):
|
||||
print("tf sharktank not implemented yet")
|
||||
|
||||
def save_tflite_model(self):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(self.tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(self.workdir, str(tflite_model_name))
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
tflite_tosa_file = "/".join(
|
||||
[
|
||||
tflite_model_name_dir,
|
||||
str(tflite_model_name) + "_tflite.mlir",
|
||||
]
|
||||
)
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
mlir_model, func_name = my_shark_importer.import_mlir()
|
||||
|
||||
if os.path.exists(tflite_tosa_file):
|
||||
print("Exists", tflite_tosa_file)
|
||||
else:
|
||||
mlir_str = mlir_model.decode("utf-8")
|
||||
with open(tflite_tosa_file, "w") as f:
|
||||
f.write(mlir_str)
|
||||
print(f"Saved mlir in {tflite_tosa_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -63,6 +63,7 @@ class TFLitePreprocessor:
|
||||
self.mlir_model = None # read of .mlir file
|
||||
self.output_tensor = None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
|
||||
self.interpreter = None # could be tflite/tf/torch_interpreter in utils
|
||||
self.input_file = None
|
||||
|
||||
# create tmp model file directory
|
||||
if self.model_path is None and self.model_name is None:
|
||||
@@ -81,7 +82,7 @@ class TFLitePreprocessor:
|
||||
return
|
||||
|
||||
if (self.input_details is None) or (self.output_details is None):
|
||||
print("Setting up tflite interpreter to get model input details")
|
||||
# print("Setting up tflite interpreter to get model input details")
|
||||
self.setup_interpreter()
|
||||
|
||||
inputs = self.generate_inputs(self.input_details) # device_inputs
|
||||
@@ -96,6 +97,7 @@ class TFLitePreprocessor:
|
||||
|
||||
self.raw_model_file = "/".join([tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"])
|
||||
self.mlir_file = "/".join([tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"])
|
||||
self.input_file = "/".join([tflite_model_name_dir, "input.json"])
|
||||
|
||||
if os.path.exists(self.raw_model_file):
|
||||
print(
|
||||
@@ -121,24 +123,6 @@ class TFLitePreprocessor:
|
||||
print("Error, No model path find in tflite_model_list.csv")
|
||||
return False
|
||||
urllib.request.urlretrieve(self.model_path, self.raw_model_file)
|
||||
# if os.path.exists(self.mlir_file):
|
||||
# print("Exists MLIR model ", self.mlir_file)
|
||||
# else:
|
||||
# print(
|
||||
# "No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
|
||||
# )
|
||||
# print("Convert tflite to tosa.mlir")
|
||||
# import iree.compiler.tflite as ireec_tflite
|
||||
#
|
||||
# ireec_tflite.compile_file(
|
||||
# self.raw_model_file,
|
||||
# input_type="tosa",
|
||||
# save_temp_iree_input=self.mlir_file,
|
||||
# target_backends=[IREE_TARGET_MAP["cpu"]],
|
||||
# import_only=False,
|
||||
# )
|
||||
# with open(self.mlir_file) as f:
|
||||
# self.mlir_model = f.read()
|
||||
return True
|
||||
|
||||
def setup_interpreter(self):
|
||||
@@ -151,19 +135,19 @@ class TFLitePreprocessor:
|
||||
def generate_inputs(self, input_details):
|
||||
self.inputs = []
|
||||
for tmp_input in input_details:
|
||||
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
# print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
self.inputs.append(np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]))
|
||||
# save inputs into json file
|
||||
tmp_json = []
|
||||
for tmp_input in input_details:
|
||||
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
# print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
tmp_json.append(np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]).tolist())
|
||||
with open("input1.json", "w") as f:
|
||||
with open(self.input_file, "w") as f:
|
||||
json.dump(tmp_json, f)
|
||||
return self.inputs
|
||||
|
||||
def setup_inputs(self, inputs):
|
||||
print("Setting up inputs")
|
||||
# print("Setting up inputs")
|
||||
self.inputs = inputs
|
||||
|
||||
def get_mlir_model(self):
|
||||
|
||||
@@ -14,4 +14,4 @@ ssd_mobilenet_v2_fpnlite_dynamic_1.0_float, https://storage.googleapis.com/iree-
|
||||
ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8.tflite
|
||||
ssd_mobilenet_v2_dynamic_1.0_int8, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_int8.tflite
|
||||
ssd_mobilenet_v2_dynamic_1.0_float, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_float.tflite
|
||||
|
||||
person_detect, https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite
|
||||
|
@@ -27,7 +27,6 @@ mobilenet_v3-large_224_1.0_float, https://storage.googleapis.com/iree-model-arti
|
||||
mobilenet_v3-large_224_1.0_uint8, https://storage.googleapis.com/iree-model-artifacts/mobilenet_v3-large_224_1.0_uint8.tflite
|
||||
mobilenet_v3.5multiavg_1.00_224_int8, https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3.5multiavg_1.0_int8/mobilenet_v3.5multiavg_1.00_224_int8.tflite
|
||||
nasnet, https://tfhub.dev/tensorflow/lite-model/nasnet/large/1/default/1?lite-format=tflite
|
||||
person_detect, https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite
|
||||
multi_person_mobilenet_v1_075_float, https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite
|
||||
resnet_50_224_int8, https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_224_int8.tflite
|
||||
squeezenet, https://tfhub.dev/tensorflow/lite-model/squeezenet/1/default/1?lite-format=tflite
|
||||
|
||||
|
Reference in New Issue
Block a user