Use the downloaded folder instead of re-downloading.

shark_tank models.
This commit is contained in:
Prashant Kumar
2022-07-18 13:43:03 +05:30
parent 54a642e76a
commit 9105f5d54e

View File

@@ -28,16 +28,33 @@ input_type_to_np_dtype = {
}
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
print(WORKDIR)
# Checks whether the directory and files exists.
def check_dir_exists(model_name):
model_dir = os.path.join(WORKDIR, model_name)
if os.path.isdir(model_dir):
if (
os.path.isfile(os.path.join(model_dir, model_name + ".mlir"))
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
):
return True
return False
# Downloads the torch model from gs://shark_tank dir.
def download_torch_model(model_name):
model_name = model_name.replace("/", "_")
os.makedirs(WORKDIR, exist_ok=True)
gs_command = (
"gsutil cp -r gs://shark_tank" + "/" + model_name + " " + WORKDIR
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(model_name):
gs_command = (
"gsutil cp -r gs://shark_tank" + "/" + model_name + " " + WORKDIR
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
model_dir = os.path.join(WORKDIR, model_name)
with open(os.path.join(model_dir, model_name + ".mlir")) as f:
mlir_file = f.read()