Add hash of mlir for tf/tflite (#225)

This commit is contained in:
Chi_Liu
2022-08-01 10:40:49 -07:00
committed by GitHub
parent 315ec72984
commit 0c31bb82cd
3 changed files with 63 additions and 4 deletions

View File

@@ -124,6 +124,10 @@ def save_tf_model(tf_model_list):
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list):
@@ -161,6 +165,16 @@ def save_tflite_model(tflite_model_list):
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
# Validates whether the file is present or not.

View File

@@ -128,9 +128,8 @@ def download_tflite_model(model_name, dynamic=False):
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tflite"
if not check_dir_exists(
model_dir_name, frontend="tflite", dynamic=dyn_str
):
def gs_download_model():
gs_command = (
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
+ "/"
@@ -141,6 +140,29 @@ def download_tflite_model(model_name, dynamic=False):
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(
model_dir_name, frontend="tflite", dynamic=dyn_str
):
gs_download_model()
else:
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
+ "/"
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
)
if os.system(gs_hash) != 0:
raise Exception("hash of the model not present in the tank.")
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
@@ -160,7 +182,8 @@ def download_tf_model(model_name):
model_name = model_name.replace("/", "_")
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tf"
if not check_dir_exists(model_dir_name, frontend="tf"):
def gs_download_model():
gs_command = (
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
+ "/"
@@ -171,6 +194,27 @@ def download_tf_model(model_name):
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(model_dir_name, frontend="tf"):
gs_download_model()
else:
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
+ "/"
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
)
if os.system(gs_hash) != 0:
raise Exception("hash of the model not present in the tank.")
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f:
mlir_file = f.read()

View File

@@ -14,5 +14,6 @@ roberta-base,hf
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,hf
funnel-transformer/small,hf
microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
1 model_name model_type
14 xlm-roberta-base hf
15 microsoft/MiniLM-L12-H384-uncased hf
16 funnel-transformer/small hf
17 microsoft/mpnet-base hf
18 facebook/convnext-tiny-224 img
19 google/vit-base-patch16-224 img