adds a flag to enable directory choice (#303)

individual tests will require implementation of the flag
alternatively, simply passing shark_default_sha in your
individual app's download function will allow for this behavior
This commit is contained in:
Daniel Garvey
2022-09-01 00:17:40 -05:00
committed by GitHub
parent 4ee164c66f
commit d45a496030
2 changed files with 12 additions and 5 deletions

View File

@@ -76,5 +76,10 @@ parser.add_argument(
action="store_true",
help="When enabled, pytest bench results will include ONNX benchmark results.",
)
parser.add_argument(
"--shark_prefix",
default="latest",
help="gs://shark_tank/<this_flag>/model_directories",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -29,8 +29,6 @@ input_type_to_np_dtype = {
"int8": np.int8,
}
# default hash is updated when nightly populate_sharktank_ci is successful
shark_default_sha = "latest"
# Save the model in the home local so it needn't be fetched everytime in the CI.
home = str(Path.home())
@@ -72,7 +70,9 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
# Downloads the torch model from gs://shark_tank dir.
def download_torch_model(model_name, dynamic=False):
def download_torch_model(
model_name, dynamic=False, shark_default_sha="latest"
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
@@ -130,7 +130,9 @@ def download_torch_model(model_name, dynamic=False):
# Downloads the tflite model from gs://shark_tank dir.
def download_tflite_model(model_name, dynamic=False):
def download_tflite_model(
model_name, dynamic=False, shark_default_sha="latest"
):
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tflite"
@@ -188,7 +190,7 @@ def download_tflite_model(model_name, dynamic=False):
return mlir_file, function_name, inputs_tuple, golden_out_tuple
def download_tf_model(model_name, tuned=None):
def download_tf_model(model_name, tuned=None, shark_default_sha="latest"):
model_name = model_name.replace("/", "_")
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tf"