Fix local artifact recognition and usage by SHARK downloader. (#286)

* Fix local artifact recognition and usage by SHARK downloader.

* Update generate_sharktank.py

* Update generate_sharktank.py
This commit is contained in:
Ean Garvey
2022-08-24 14:37:16 -05:00
committed by GitHub
parent f79a6bf5aa
commit 14857770dc
2 changed files with 18 additions and 9 deletions

View File

@@ -2,10 +2,11 @@
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
# HOME
# /.local
# /shark_tank
# /albert_lite_base
# /...model_name...
#
import os
@@ -16,6 +17,7 @@ import tensorflow as tf
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
visible_default = tf.config.list_physical_devices("GPU")
try:
@@ -28,7 +30,8 @@ except:
pass
# All generated models and metadata will be saved under this directory.
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
home = str(Path.home())
WORKDIR = os.path.join(home, ".local/shark_tank/")
def create_hash(file_name):
@@ -237,5 +240,5 @@ if __name__ == "__main__":
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
print("uploading files to gs://shark_tank/" + git_hash)
os.system(
"gsutil cp -r ./gen_shark_tank/* gs://shark_tank/" + git_hash
"gsutil cp -r ~/.local/shark_tank/* gs://shark_tank/" + git_hash
)

View File

@@ -110,7 +110,9 @@ def download_torch_model(model_name, dynamic=False):
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
@@ -167,7 +169,9 @@ def download_tflite_model(model_name, dynamic=False):
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
@@ -221,7 +225,9 @@ def download_tf_model(model_name):
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f: