mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
adds a flag to enable directory choice (#303)
Individual tests will require implementation of the flag. Alternatively, simply passing `shark_default_sha` in your individual app's download function will allow for this behavior.
This commit is contained in:
@@ -76,5 +76,10 @@ parser.add_argument(
|
||||
action="store_true",
|
||||
help="When enabled, pytest bench results will include ONNX benchmark results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shark_prefix",
|
||||
default="latest",
|
||||
help="gs://shark_tank/<this_flag>/model_directories",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -29,8 +29,6 @@ input_type_to_np_dtype = {
|
||||
"int8": np.int8,
|
||||
}
|
||||
|
||||
# default hash is updated when nightly populate_sharktank_ci is successful
|
||||
shark_default_sha = "latest"
|
||||
|
||||
# Save the model in the home local so it needn't be fetched everytime in the CI.
|
||||
home = str(Path.home())
|
||||
@@ -72,7 +70,9 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
def download_torch_model(model_name, dynamic=False):
|
||||
def download_torch_model(
|
||||
model_name, dynamic=False, shark_default_sha="latest"
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
@@ -130,7 +130,9 @@ def download_torch_model(model_name, dynamic=False):
|
||||
|
||||
|
||||
# Downloads the tflite model from gs://shark_tank dir.
|
||||
def download_tflite_model(model_name, dynamic=False):
|
||||
def download_tflite_model(
|
||||
model_name, dynamic=False, shark_default_sha="latest"
|
||||
):
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tflite"
|
||||
@@ -188,7 +190,7 @@ def download_tflite_model(model_name, dynamic=False):
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
def download_tf_model(model_name, tuned=None):
|
||||
def download_tf_model(model_name, tuned=None, shark_default_sha="latest"):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tf"
|
||||
|
||||
Reference in New Issue
Block a user