Rename tank tflite/torch model dir (#219)

This commit is contained in:
Chi_Liu
2022-07-28 23:56:05 -07:00
committed by GitHub
parent ff20ddeb97
commit c1cb7bb3fd
2 changed files with 28 additions and 12 deletions

View File

@@ -35,6 +35,7 @@ home = str(Path.home())
WORKDIR = os.path.join(home, ".local/shark_tank/")
print(WORKDIR)
# Checks whether the directory and files exists.
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)
@@ -42,11 +43,18 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
# Remove the _tf keyword from end.
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
if os.path.isdir(model_dir):
if (
os.path.isfile(
os.path.join(model_dir, model_name + dynamic + ".mlir")
os.path.join(
model_dir,
model_name + dynamic + "_" + str(frontend) + ".mlir",
)
)
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
@@ -65,27 +73,28 @@ def download_torch_model(model_name, dynamic=False):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_torch"
def gs_download_model():
gs_command = (
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
+ "/"
+ model_name
+ model_dir_name
+ " "
+ WORKDIR
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(model_name, dyn_str):
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
gs_download_model()
else:
model_dir = os.path.join(WORKDIR, model_name)
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
+ "/"
+ model_name
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
@@ -98,8 +107,10 @@ def download_torch_model(model_name, dynamic=False):
if local_hash != upstream_hash:
gs_download_model()
model_dir = os.path.join(WORKDIR, model_name)
with open(os.path.join(model_dir, model_name + dyn_str + ".mlir")) as f:
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir")
) as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
@@ -115,18 +126,21 @@ def download_torch_model(model_name, dynamic=False):
def download_tflite_model(model_name, dynamic=False):
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
if not check_dir_exists(model_name, dyn_str):
model_dir_name = model_name + "_tflite"
if not check_dir_exists(
model_dir_name, frontend="tflite", dynamic=dyn_str
):
gs_command = (
'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
+ "/"
+ model_name
+ model_dir_name
+ " "
+ WORKDIR
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
model_dir = os.path.join(WORKDIR, model_name)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
) as f: