Add tf/torch/mhlo/tosa support for SharkDownloader (#151)

This commit is contained in:
Chi_Liu
2022-06-22 11:25:34 -07:00
committed by GitHub
parent e8aa105b2a
commit a635b6fbef
2 changed files with 63 additions and 33 deletions

View File

@@ -40,9 +40,12 @@ class SharkDownloader:
self.local_tank_dir = local_tank_dir
self.tank_url = tank_url
self.model_type = model_type
self.input_json = input_json
self.input_type = input_type_to_np_dtype[input_type]
self.input_json = input_json # optional if you don't have input
self.input_type = input_type_to_np_dtype[
input_type
] # optional if you don't have input
self.mlir_file = None # .mlir file local address.
self.mlir_url = None
self.inputs = None # Input has to be (list of np.array) for sharkInference.forward use
self.mlir_model = []
@@ -73,51 +76,78 @@ class SharkDownloader:
]
else:
print(
"No json input required for current model. You could call setup_inputs(you_inputs)."
"No json input required for current model type. You could call setup_inputs(YOU_INPUTS)."
)
return self.inputs
def load_mlir_model(self):
if self.model_type in ["tflite-tosa"]:
workdir = os.path.join(
os.path.dirname(__file__), self.local_tank_dir
workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
os.makedirs(workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {workdir}")
# use model name get dir.
model_name_dir = os.path.join(workdir, str(self.model_name))
if not os.path.exists(model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by tank_url if provided."
" You can also manually to download the model from shark_tank by yourself."
)
os.makedirs(workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {workdir}")
os.makedirs(model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
# use model name get dir.
model_name_dir = os.path.join(workdir, str(self.model_name))
if not os.path.exists(model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by tank_url if provided."
" You can also manually to download the model from shark_tank by yourself."
)
os.makedirs(model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
mlir_url = (
if self.model_type in ["tflite-tosa"]:
self.mlir_url = (
self.tank_url
+ "/tflite/"
+ "/"
+ str(self.model_name)
+ "/"
+ str(self.model_name)
+ "_tosa.mlir"
+ "_tflite.mlir"
)
self.mlir_file = "/".join(
[model_name_dir, str(self.model_name) + "_tosa.mlir"]
[model_name_dir, str(self.model_name) + "_tfite.mlir"]
)
elif self.model_type in ["tensorflow"]:
self.mlir_url = (
self.tank_url
+ "/"
+ str(self.model_name)
+ "/"
+ str(self.model_name)
+ "_tf.mlir"
)
self.mlir_file = "/".join(
[model_name_dir, str(self.model_name) + "_tf.mlir"]
)
elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
self.mlir_url = (
self.tank_url
+ "/"
+ str(self.model_name)
+ "/"
+ str(self.model_name)
+ "_"
+ str(self.model_type)
+ ".mlir"
)
self.mlir_file = "/".join(
[
model_name_dir,
str(self.model_name) + "_" + str(self.model_type) + ".mlir",
]
)
if os.path.exists(self.mlir_file):
print("Model has been downloaded before.", self.mlir_file)
else:
print("Download mlir model", mlir_url)
urllib.request.urlretrieve(mlir_url, self.mlir_file)
print("Get tosa.mlir model return")
with open(self.mlir_file) as f:
self.mlir_model = f.read()
else:
print("Unsupported mlir model")
if os.path.exists(self.mlir_file):
print("Model has been downloaded before.", self.mlir_file)
else:
print("Download mlir model", self.mlir_url)
urllib.request.urlretrieve(self.mlir_url, self.mlir_file)
print("Get .mlir model return")
with open(self.mlir_file) as f:
self.mlir_model = f.read()
return self.mlir_model
def setup_inputs(self, inputs):

View File

@@ -24,7 +24,7 @@ class AlbertTfliteModuleTester:
self.shark_downloader = SharkDownloader(
model_name="albert_lite_base",
tank_url="https://storage.googleapis.com/shark_tank",
local_tank_dir="./../gen_shark_tank/tflite",
local_tank_dir="./../gen_shark_tank",
model_type="tflite-tosa",
input_json="input.json",
input_type="int32",