generate shark tank for tflite (#173)

* Add gen_shark_tank support tflite

* gen_shark_tank.py use SharkImporter to save model
This commit is contained in:
Chi_Liu
2022-07-05 23:11:19 -07:00
committed by GitHub
parent c351bb50b6
commit 1cad50d521
4 changed files with 71 additions and 88 deletions

View File

@@ -15,8 +15,7 @@ import os
import urllib.request
import csv
import argparse
import iree.compiler.tflite as ireec_tflite
from shark.iree_utils._common import IREE_TARGET_MAP
from shark.shark_importer import SharkImporter
class SharkTank:
@@ -32,77 +31,78 @@ class SharkTank:
self.tflite_model_list = tflite_model_list
self.upload = upload
print("Setting up for TMP_DIR")
self.workdir = os.path.join(os.path.dirname(__file__), "./gen_shark_tank")
print(f"tflite TMP_shark_tank_DIR = {self.workdir}")
os.makedirs(self.workdir, exist_ok=True)
if self.torch_model_list is not None:
print("Process torch model")
else:
print("Torch sharktank not implemented yet")
self.save_torch_model()
if self.tf_model_list is not None:
print("Process torch model")
else:
print("tf sharktank not implemented yet")
self.save_tf_model()
print("self.tflite_model_list: ", self.tflite_model_list)
# compile and run tfhub tflite
if self.tflite_model_list is not None:
print("Setting up for tflite TMP_DIR")
self.tflite_workdir = os.path.join(os.path.dirname(__file__), "./gen_shark_tank")
print(f"tflite TMP_shark_tank_DIR = {self.tflite_workdir}")
os.makedirs(self.tflite_workdir, exist_ok=True)
self.save_tflite_model()
with open(self.tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(self.tflite_workdir, str(tflite_model_name))
os.makedirs(tflite_model_name_dir, exist_ok=True)
tflite_saving_file = "/".join(
[
tflite_model_name_dir,
str(tflite_model_name) + "_tflite.tflite",
]
)
tflite_tosa_file = "/".join(
[
tflite_model_name_dir,
str(tflite_model_name) + "_tflite.mlir",
]
)
self.binary = "/".join(
[
tflite_model_name_dir,
str(tflite_model_name) + "_module.bytecode",
]
)
print(
"Setting up local address for tflite model file: ",
tflite_saving_file,
)
if os.path.exists(tflite_saving_file):
print(tflite_saving_file, "exists")
else:
print("Download tflite model")
urllib.request.urlretrieve(str(tflite_model_link), tflite_saving_file)
if os.path.exists(tflite_tosa_file):
print("Exists", tflite_tosa_file)
else:
print("Convert tflite to tosa.mlir")
ireec_tflite.compile_file(
tflite_saving_file,
input_type="tosa",
save_temp_iree_input=tflite_tosa_file,
target_backends=[IREE_TARGET_MAP["cpu"]],
import_only=False,
)
if self.upload == True:
if self.upload:
print("upload tmp tank to gcp")
os.system("gsutil cp -r ./gen_shark_tank gs://shark_tank/")
def save_torch_model(self):
    """Placeholder for exporting torch models from ``self.torch_model_list``.

    Torch support is not implemented yet; this only logs that fact and
    returns ``None``.
    """
    # NOTE(review): the original did `from tank.model_utils import get_hf_model`
    # here but never used the name — the dead import has been removed so the
    # method no longer pays an import cost (or fails) for nothing.
    print("Torch sharktank not implemented yet")
def save_tf_model(self):
    """Placeholder for exporting TensorFlow models; not implemented yet."""
    notice = "tf sharktank not implemented yet"
    print(notice)
def save_tflite_model(self):
    """Import every tflite model listed in ``self.tflite_model_list`` and
    save its TOSA MLIR under ``self.workdir``.

    For each ``name, url`` row of the csv: create ``<workdir>/<name>/``,
    run ``TFLitePreprocessor`` (which locates/downloads the raw ``.tflite``
    and prepares importer inputs), build a ``SharkImporter`` and emit MLIR,
    then write it to ``<name>_tflite.mlir`` unless that file already exists.
    """
    from shark.tflite_utils import TFLitePreprocessor

    with open(self.tflite_model_list) as csvfile:
        for row in csv.reader(csvfile, delimiter=","):
            name = row[0]
            link = row[1]
            print("tflite_model_name", name)
            print("tflite_model_link", link)

            model_dir = os.path.join(self.workdir, str(name))
            os.makedirs(model_dir, exist_ok=True)
            print(f"TMP_TFLITE_MODELNAME_DIR = {model_dir}")
            tosa_path = "/".join([model_dir, str(name) + "_tflite.mlir"])

            # Preprocess first (even when the .mlir already exists) so the
            # raw model file and inputs are always fetched, as before.
            preprocessor = TFLitePreprocessor(str(name))
            raw_model_path = preprocessor.get_raw_model_file()
            example_inputs = preprocessor.get_inputs()
            interpreter = preprocessor.get_interpreter()

            importer = SharkImporter(
                module=interpreter,
                inputs=example_inputs,
                frontend="tflite",
                raw_model_file=raw_model_path,
            )
            mlir_model, _func_name = importer.import_mlir()

            if os.path.exists(tosa_path):
                print("Exists", tosa_path)
            else:
                # import_mlir yields bytes here — decode before writing text.
                with open(tosa_path, "w") as f:
                    f.write(mlir_model.decode("utf-8"))
                print(f"Saved mlir in {tosa_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()

View File

@@ -63,6 +63,7 @@ class TFLitePreprocessor:
self.mlir_model = None # read of .mlir file
self.output_tensor = None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
self.interpreter = None # could be tflite/tf/torch_interpreter in utils
self.input_file = None
# create tmp model file directory
if self.model_path is None and self.model_name is None:
@@ -81,7 +82,7 @@ class TFLitePreprocessor:
return
if (self.input_details is None) or (self.output_details is None):
print("Setting up tflite interpreter to get model input details")
# print("Setting up tflite interpreter to get model input details")
self.setup_interpreter()
inputs = self.generate_inputs(self.input_details) # device_inputs
@@ -96,6 +97,7 @@ class TFLitePreprocessor:
self.raw_model_file = "/".join([tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"])
self.mlir_file = "/".join([tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"])
self.input_file = "/".join([tflite_model_name_dir, "input.json"])
if os.path.exists(self.raw_model_file):
print(
@@ -121,24 +123,6 @@ class TFLitePreprocessor:
print("Error, No model path find in tflite_model_list.csv")
return False
urllib.request.urlretrieve(self.model_path, self.raw_model_file)
# if os.path.exists(self.mlir_file):
# print("Exists MLIR model ", self.mlir_file)
# else:
# print(
# "No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
# )
# print("Convert tflite to tosa.mlir")
# import iree.compiler.tflite as ireec_tflite
#
# ireec_tflite.compile_file(
# self.raw_model_file,
# input_type="tosa",
# save_temp_iree_input=self.mlir_file,
# target_backends=[IREE_TARGET_MAP["cpu"]],
# import_only=False,
# )
# with open(self.mlir_file) as f:
# self.mlir_model = f.read()
return True
def setup_interpreter(self):
@@ -151,19 +135,19 @@ class TFLitePreprocessor:
def generate_inputs(self, input_details):
self.inputs = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
# print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
self.inputs.append(np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]))
# save inputs into json file
tmp_json = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
# print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
tmp_json.append(np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]).tolist())
with open("input1.json", "w") as f:
with open(self.input_file, "w") as f:
json.dump(tmp_json, f)
return self.inputs
def setup_inputs(self, inputs):
print("Setting up inputs")
# print("Setting up inputs")
self.inputs = inputs
def get_mlir_model(self):

View File

@@ -14,4 +14,4 @@ ssd_mobilenet_v2_fpnlite_dynamic_1.0_float, https://storage.googleapis.com/iree-
ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8.tflite
ssd_mobilenet_v2_dynamic_1.0_int8, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_int8.tflite
ssd_mobilenet_v2_dynamic_1.0_float, https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_float.tflite
person_detect, https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite
1 ASR_TFLite https://tfhub.dev/neso613/lite-model/ASR_TFLite/pre_trained_models/English/1?lite-format=tflite
14 ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8 https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_fpnlite_dynamic_1.0_uint8.tflite
15 ssd_mobilenet_v2_dynamic_1.0_int8 https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_int8.tflite
16 ssd_mobilenet_v2_dynamic_1.0_float https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_dynamic_1.0_float.tflite
17 person_detect https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite

View File

@@ -27,7 +27,6 @@ mobilenet_v3-large_224_1.0_float, https://storage.googleapis.com/iree-model-arti
mobilenet_v3-large_224_1.0_uint8, https://storage.googleapis.com/iree-model-artifacts/mobilenet_v3-large_224_1.0_uint8.tflite
mobilenet_v3.5multiavg_1.00_224_int8, https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3.5multiavg_1.0_int8/mobilenet_v3.5multiavg_1.00_224_int8.tflite
nasnet, https://tfhub.dev/tensorflow/lite-model/nasnet/large/1/default/1?lite-format=tflite
person_detect, https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite
multi_person_mobilenet_v1_075_float, https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite
resnet_50_224_int8, https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_224_int8.tflite
squeezenet, https://tfhub.dev/tensorflow/lite-model/squeezenet/1/default/1?lite-format=tflite
1 albert_lite_base https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite
27 mobilenet_v3-large_224_1.0_uint8 https://storage.googleapis.com/iree-model-artifacts/mobilenet_v3-large_224_1.0_uint8.tflite
28 mobilenet_v3.5multiavg_1.00_224_int8 https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3.5multiavg_1.0_int8/mobilenet_v3.5multiavg_1.00_224_int8.tflite
29 nasnet https://tfhub.dev/tensorflow/lite-model/nasnet/large/1/default/1?lite-format=tflite
person_detect https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite
30 multi_person_mobilenet_v1_075_float https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/multi_person_mobilenet_v1_075_float.tflite
31 resnet_50_224_int8 https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_224_int8.tflite
32 squeezenet https://tfhub.dev/tensorflow/lite-model/squeezenet/1/default/1?lite-format=tflite