mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Use the downloaded folder instead of re-downloading.
shark_tank models.
This commit is contained in:
@@ -28,16 +28,33 @@ input_type_to_np_dtype = {
|
||||
}
|
||||
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
print(WORKDIR)
|
||||
|
||||
# Checks whether the directory and files exists.
|
||||
def check_dir_exists(model_name):
|
||||
model_dir = os.path.join(WORKDIR, model_name)
|
||||
if os.path.isdir(model_dir):
|
||||
if (
|
||||
os.path.isfile(os.path.join(model_dir, model_name + ".mlir"))
|
||||
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
|
||||
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
|
||||
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
def download_torch_model(model_name):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
gs_command = (
|
||||
"gsutil cp -r gs://shark_tank" + "/" + model_name + " " + WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
if not check_dir_exists(model_name):
|
||||
gs_command = (
|
||||
"gsutil cp -r gs://shark_tank" + "/" + model_name + " " + WORKDIR
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_name)
|
||||
with open(os.path.join(model_dir, model_name + ".mlir")) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
Reference in New Issue
Block a user