From 0c31bb82cd30687b706dc8ab12bbb6026686ef57 Mon Sep 17 00:00:00 2001 From: Chi_Liu Date: Mon, 1 Aug 2022 10:40:49 -0700 Subject: [PATCH] Add hash of mlir for tf/tflite (#225) --- generate_sharktank.py | 14 +++++++++++ shark/shark_downloader.py | 52 ++++++++++++++++++++++++++++++++++++--- tank/tf/tf_model_list.csv | 1 + 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/generate_sharktank.py b/generate_sharktank.py index 8056cae9..3c6e3ede 100644 --- a/generate_sharktank.py +++ b/generate_sharktank.py @@ -124,6 +124,10 @@ def save_tf_model(tf_model_list): dir=tf_model_dir, model_name=tf_model_name, ) + mlir_hash = create_hash( + os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir") + ) + np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash)) def save_tflite_model(tflite_model_list): @@ -161,6 +165,16 @@ def save_tflite_model(tflite_model_list): model_name=tflite_model_name, func_name="main", ) + mlir_hash = create_hash( + os.path.join( + tflite_model_name_dir, + tflite_model_name + "_tflite" + ".mlir", + ) + ) + np.save( + os.path.join(tflite_model_name_dir, "hash"), + np.array(mlir_hash), + ) # Validates whether the file is present or not. diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 6b144fc6..798dd109 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -128,9 +128,8 @@ def download_tflite_model(model_name, dynamic=False): dyn_str = "_dynamic" if dynamic else "" os.makedirs(WORKDIR, exist_ok=True) model_dir_name = model_name + "_tflite" - if not check_dir_exists( - model_dir_name, frontend="tflite", dynamic=dyn_str - ): + + def gs_download_model(): gs_command = ( 'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank' + "/" @@ -141,6 +140,29 @@ def download_tflite_model(model_name, dynamic=False): if os.system(gs_command) != 0: raise Exception("model not present in the tank. Contact Nod Admin") + if not check_dir_exists( + model_dir_name, frontend="tflite", dynamic=dyn_str + ): + gs_download_model() + else: + model_dir = os.path.join(WORKDIR, model_dir_name) + local_hash = str(np.load(os.path.join(model_dir, "hash.npy"))) + gs_hash = ( + 'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank' + + "/" + + model_dir_name + + "/hash.npy" + + " " + + os.path.join(model_dir, "upstream_hash.npy") + ) + if os.system(gs_hash) != 0: + raise Exception("hash of the model not present in the tank.") + upstream_hash = str( + np.load(os.path.join(model_dir, "upstream_hash.npy")) + ) + if local_hash != upstream_hash: + gs_download_model() + model_dir = os.path.join(WORKDIR, model_dir_name) with open( os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir") @@ -160,7 +182,8 @@ def download_tf_model(model_name): model_name = model_name.replace("/", "_") os.makedirs(WORKDIR, exist_ok=True) model_dir_name = model_name + "_tf" - if not check_dir_exists(model_dir_name, frontend="tf"): + + def gs_download_model(): gs_command = ( 'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank' + "/" @@ -171,6 +194,27 @@ def download_tf_model(model_name): if os.system(gs_command) != 0: raise Exception("model not present in the tank. Contact Nod Admin") + if not check_dir_exists(model_dir_name, frontend="tf"): + gs_download_model() + else: + model_dir = os.path.join(WORKDIR, model_dir_name) + local_hash = str(np.load(os.path.join(model_dir, "hash.npy"))) + gs_hash = ( + 'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank' + + "/" + + model_dir_name + + "/hash.npy" + + " " + + os.path.join(model_dir, "upstream_hash.npy") + ) + if os.system(gs_hash) != 0: + raise Exception("hash of the model not present in the tank.") + upstream_hash = str( + np.load(os.path.join(model_dir, "upstream_hash.npy")) + ) + if local_hash != upstream_hash: + gs_download_model() + model_dir = os.path.join(WORKDIR, model_dir_name) with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f: mlir_file = f.read() diff --git a/tank/tf/tf_model_list.csv b/tank/tf/tf_model_list.csv index 53f3cdf6..8186185d 100644 --- a/tank/tf/tf_model_list.csv +++ b/tank/tf/tf_model_list.csv @@ -14,5 +14,6 @@ roberta-base,hf xlm-roberta-base,hf microsoft/MiniLM-L12-H384-uncased,hf funnel-transformer/small,hf +microsoft/mpnet-base,hf facebook/convnext-tiny-224,img google/vit-base-patch16-224,img