mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
Add hash of mlir for tf/tflite (#225)
This commit is contained in:
@@ -124,6 +124,10 @@ def save_tf_model(tf_model_list):
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
@@ -161,6 +165,16 @@ def save_tflite_model(tflite_model_list):
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
|
||||
@@ -128,9 +128,8 @@ def download_tflite_model(model_name, dynamic=False):
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tflite"
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend="tflite", dynamic=dyn_str
|
||||
):
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
|
||||
+ "/"
|
||||
@@ -141,6 +140,29 @@ def download_tflite_model(model_name, dynamic=False):
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend="tflite", dynamic=dyn_str
|
||||
):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
gs_download_model()
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
|
||||
@@ -160,7 +182,8 @@ def download_tf_model(model_name):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tf"
|
||||
if not check_dir_exists(model_dir_name, frontend="tf"):
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
|
||||
+ "/"
|
||||
@@ -171,6 +194,27 @@ def download_tf_model(model_name):
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="tf"):
|
||||
gs_download_model()
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
gs_download_model()
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
@@ -14,5 +14,6 @@ roberta-base,hf
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,hf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/mpnet-base,hf
|
||||
facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
|
||||
|
Reference in New Issue
Block a user