# Lint as: python3
"""AMDSHARK Downloader"""
# Requirements: Put amdshark_tank in the AMDSHARK directory:
#   /AMDSHARK
#     /gen_amdshark_tank
#       /tflite
#         /albert_lite_base
#         /...model_name...
#       /tf
#       /pytorch

import os
import sys
from pathlib import Path

import numpy as np
from google.cloud import storage
from tqdm.std import tqdm

from amdshark.parser import amdshark_args


def download_public_file(
    full_gs_url, destination_folder_name, single_file=False
):
    """Downloads a public blob from the bucket."""
    # bucket_name = "gs://your-bucket-name/path/to/file"
    # destination_file_name = "local/path/to/file"

    storage_client = storage.Client.create_anonymous_client()
    bucket_name = full_gs_url.split("/")[2]
    source_blob_name = None
    dest_filename = None
    desired_file = None
    if single_file:
        desired_file = full_gs_url.split("/")[-1]
        source_blob_name = "/".join(full_gs_url.split("/")[3:-1])
        destination_folder_name, dest_filename = os.path.split(
            destination_folder_name
        )
    else:
        source_blob_name = "/".join(full_gs_url.split("/")[3:])
    bucket = storage_client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=source_blob_name)
    if not os.path.exists(destination_folder_name):
        os.mkdir(destination_folder_name)
    for blob in blobs:
        blob_name = blob.name.split("/")[-1]
        if single_file:
            if blob_name == desired_file:
                destination_filename = os.path.join(
                    destination_folder_name, dest_filename
                )
                with open(destination_filename, "wb") as f:
                    with tqdm.wrapattr(
                        f, "write", total=blob.size
                    ) as file_obj:
                        storage_client.download_blob_to_file(blob, file_obj)
            else:
                continue
        else:
            destination_filename = os.path.join(
                destination_folder_name, blob_name
            )
            if os.path.isdir(destination_filename):
                continue
            with open(destination_filename, "wb") as f:
                with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
                    storage_client.download_blob_to_file(blob, file_obj)

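# Example usage (a hedged sketch; the URL and local paths below are
# hypothetical placeholders, not guaranteed tank contents):
#
#   # Mirror every blob under a tank directory into a local folder:
#   download_public_file(
#       "gs://amdshark_tank/nightly/albert_lite_base_tflite",
#       "/tmp/albert_lite_base_tflite",
#   )
#   # Fetch exactly one file; the destination's basename is used as the
#   # local filename:
#   download_public_file(
#       "gs://amdshark_tank/nightly/albert_lite_base_tflite/hash.npy",
#       "/tmp/albert_lite_base_tflite/hash.npy",
#       single_file=True,
#   )
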
input_type_to_np_dtype = {
    "float32": np.float32,
    "float64": np.float64,
    "bool": np.bool_,
    "int32": np.int32,
    "int64": np.int64,
    "uint8": np.uint8,
    "int8": np.int8,
}

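# Maps dtype name strings to concrete NumPy dtypes, e.g.
# input_type_to_np_dtype["int32"] is np.int32. (Not used in this file;
# presumably consumed by callers that parse model input signatures.)
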
# Cache models under the home directory so they needn't be fetched every
# time in CI.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_amdshark_tank/")
custom_path = amdshark_args.local_tank_cache

if custom_path is not None:
    if not os.path.exists(custom_path):
        os.mkdir(custom_path)

    WORKDIR = custom_path

    print(f"Using {WORKDIR} as local amdshark_tank cache directory.")

elif os.path.exists(alt_path):
    WORKDIR = alt_path
    print(
        f"Using {WORKDIR} as amdshark_tank directory. Delete this directory"
        " if you aren't working from a locally generated amdshark_tank."
    )
else:
    WORKDIR = os.path.join(home, ".local/amdshark_tank/")
    print(
        f"amdshark_tank local cache is located at {WORKDIR}. You may change"
        " this by setting the --local_tank_cache= flag."
    )
    os.makedirs(WORKDIR, exist_ok=True)

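# Cache-directory precedence, as implemented above:
#   1. --local_tank_cache=<path>  (amdshark_args.local_tank_cache)
#   2. <repo>/gen_amdshark_tank/  (locally generated artifacts, if present)
#   3. ~/.local/amdshark_tank/    (default download cache)

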
# Checks whether the model directory and its artifact files exist.
def check_dir_exists(model_name, frontend="torch", dynamic=""):
    model_dir = os.path.join(WORKDIR, model_name)

    # Remove the frontend suffix from the end, but only for non-SD models.
    if not any(model in model_name for model in ["clip", "unet", "vae"]):
        if frontend in ["tf", "tensorflow"]:
            model_name = model_name[:-3]  # strip "_tf"
        elif frontend in ["tflite"]:
            model_name = model_name[:-7]  # strip "_tflite"
        elif frontend in ["torch", "pytorch"]:
            model_name = model_name[:-6]  # strip "_torch"

    model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"

    if os.path.isdir(model_dir):
        if (
            os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
            and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
            and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
            and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
            and os.path.isfile(os.path.join(model_dir, "hash.npy"))
        ):
            print(
                f"""Model artifacts for {model_name} found at {WORKDIR}..."""
            )
            return True
    return False

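# Expected on-disk layout for a cached model, e.g. for a hypothetical
# "albert_lite_base_tflite" entry (filenames per the checks above):
#   <WORKDIR>/albert_lite_base_tflite/
#     albert_lite_base_tflite.mlir  (or albert_lite_base_dynamic_tflite.mlir)
#     function_name.npy
#     inputs.npz
#     golden_out.npz
#     hash.npy

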
def _internet_connected():
    import requests as req

    try:
        # A short timeout keeps offline runs from hanging indefinitely.
        req.get("http://1.1.1.1", timeout=5)
        return True
    except req.exceptions.RequestException:
        return False


def get_git_revision_short_hash() -> str:
    """Returns the amdshark_tank version prefix.

    Despite the name, this reads the prefix from the --amdshark_prefix flag
    if set, and otherwise from tank_version.json; it does not query git.
    """
    if amdshark_args.amdshark_prefix is not None:
        prefix_kw = amdshark_args.amdshark_prefix
    else:
        import json

        dir_path = os.path.dirname(os.path.realpath(__file__))
        src = os.path.join(dir_path, "..", "tank_version.json")
        with open(src, "r") as f:
            data = json.loads(f.read())
        prefix_kw = data["version"]
    print(f"Checking for updates from gs://amdshark_tank/{prefix_kw}")
    return prefix_kw

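# tank_version.json is expected to hold the pinned tank prefix in the shape
# read above, e.g. (hypothetical value):
#   { "version": "nightly" }

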
def get_amdsharktank_prefix():
    tank_prefix = ""
    if not _internet_connected():
        print(
            "No internet connection. Using the model already present in the"
            " tank."
        )
        tank_prefix = "none"
    else:
        desired_prefix = get_git_revision_short_hash()
        storage_client_a = storage.Client.create_anonymous_client()
        base_bucket_name = "amdshark_tank"
        base_bucket = storage_client_a.bucket(base_bucket_name)
        dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
        for blob in dir_blobs:
            dir_blob_name = blob.name.split("/")
            if desired_prefix in dir_blob_name[0]:
                tank_prefix = dir_blob_name[0]
                break
        if tank_prefix == "":
            print(
                f"No amdshark_tank directory found matching"
                f" ({desired_prefix}). Defaulting to nightly."
            )
            tank_prefix = "nightly"
    return tank_prefix

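# Resolution behavior (per the code above): offline -> "none"; a matching
# top-level gs://amdshark_tank directory -> that directory's name; otherwise
# the "nightly" fallback.

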
# Downloads model artifacts from the gs://amdshark_tank bucket.
def download_model(
    model_name,
    dynamic=False,
    tank_url=None,
    frontend=None,
    tuned=None,
    import_args={"batch_size": 1},
):
    model_name = model_name.replace("/", "_")
    dyn_str = "_dynamic" if dynamic else ""
    os.makedirs(WORKDIR, exist_ok=True)
    amdshark_args.amdshark_prefix = get_amdsharktank_prefix()
    if import_args["batch_size"] and import_args["batch_size"] != 1:
        model_dir_name = (
            model_name
            + "_"
            + frontend
            + "_BS"
            + str(import_args["batch_size"])
        )
    elif any(model in model_name for model in ["clip", "unet", "vae"]):
        # TODO(Ean Garvey): rework extended naming such that device is only
        # included in model_name after .vmfb compilation.
        model_dir_name = model_name
    else:
        model_dir_name = model_name + "_" + frontend
    model_dir = os.path.join(WORKDIR, model_dir_name)

    if not tank_url:
        tank_url = "gs://amdshark_tank/" + amdshark_args.amdshark_prefix

    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
    if not check_dir_exists(
        model_dir_name, frontend=frontend, dynamic=dyn_str
    ):
        print(
            f"Downloading artifacts for model {model_name} from:"
            f" {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)
    elif amdshark_args.force_update_tank:
        print(
            f"Force-updating artifacts for model {model_name} from:"
            f" {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)
    else:
        if not _internet_connected():
            print(
                "No internet connection. Using the model already present in"
                " the tank."
            )
        else:
            local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
            gs_hash_url = (
                tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
            )
            download_public_file(
                gs_hash_url,
                os.path.join(model_dir, "upstream_hash.npy"),
                single_file=True,
            )
            try:
                upstream_hash = str(
                    np.load(os.path.join(model_dir, "upstream_hash.npy"))
                )
            except FileNotFoundError:
                print(f"Model artifact hash not found at {model_dir}.")
                upstream_hash = None
            if local_hash != upstream_hash and amdshark_args.update_tank:
                print(f"Updating artifacts for model {model_name}...")
                download_public_file(full_gs_url, model_dir)
            elif local_hash != upstream_hash:
                print(
                    "Hash does not match upstream in gs://amdshark_tank/. If"
                    " you want to use locally generated artifacts, this is"
                    " working as intended. Otherwise, run with --update_tank."
                )
            else:
                print(
                    "Local and upstream hashes match. Using cached model"
                    " artifacts."
                )

    model_dir = os.path.join(WORKDIR, model_dir_name)
    tuned_str = "" if tuned is None else "_" + tuned
    suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
    mlir_filename = os.path.join(model_dir, model_name + suffix)
    print(
        f"Verifying that model artifacts were downloaded successfully to"
        f" {mlir_filename}..."
    )
    if not os.path.exists(mlir_filename):
        from tank.generate_amdsharktank import gen_amdshark_files

        print(
            "The model data was not found. Trying to generate artifacts"
            " locally."
        )
        gen_amdshark_files(model_name, frontend, WORKDIR, import_args)

    assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))

    inputs_tuple = tuple(inputs[key] for key in inputs)
    golden_out_tuple = tuple(golden_out[key] for key in golden_out)
    return mlir_filename, function_name, inputs_tuple, golden_out_tuple
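

# Example usage (a hedged sketch; the model name is illustrative and not
# guaranteed to exist in the tank):
#
#   mlir_file, func_name, inputs, golden_out = download_model(
#       "albert_lite_base",
#       dynamic=False,
#       frontend="tflite",
#   )
#   # mlir_file  -> <WORKDIR>/albert_lite_base_tflite/
#   #               albert_lite_base_tflite.mlir
#   # func_name  -> entry function name stored in function_name.npy
#   # inputs     -> tuple of np.ndarray sample inputs (from inputs.npz)
#   # golden_out -> tuple of reference outputs (from golden_out.npz)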