mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Fix local artifact recognition and usage by SHARK downloader. (#286)
* Fix local artifact recognition and usage by SHARK downloader. * Update generate_sharktank.py * Update generate_sharktank.py
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
"""SHARK Tank"""
|
||||
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# HOME
|
||||
# /.local
|
||||
# /shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
#
|
||||
|
||||
import os
|
||||
@@ -16,6 +17,7 @@ import tensorflow as tf
|
||||
import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
@@ -28,7 +30,8 @@ except:
|
||||
pass
|
||||
|
||||
# All generated models and metadata will be saved under this directory.
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
home = str(Path.home())
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
@@ -237,5 +240,5 @@ if __name__ == "__main__":
|
||||
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
|
||||
print("uploading files to gs://shark_tank/" + git_hash)
|
||||
os.system(
|
||||
"gsutil cp -r ./gen_shark_tank/* gs://shark_tank/" + git_hash
|
||||
"gsutil cp -r ~/.local/shark_tank/* gs://shark_tank/" + git_hash
|
||||
)
|
||||
|
||||
@@ -110,7 +110,9 @@ def download_torch_model(model_name, dynamic=False):
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
gs_download_model()
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
@@ -167,7 +169,9 @@ def download_tflite_model(model_name, dynamic=False):
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
gs_download_model()
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
@@ -221,7 +225,9 @@ def download_tf_model(model_name):
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
gs_download_model()
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f:
|
||||
|
||||
Reference in New Issue
Block a user