From 2191fc8952015b8a9989b4f028946a2f76225b86 Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Mon, 3 Apr 2023 10:24:21 -0500
Subject: [PATCH] Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)

* Only xfail windows models in CI
* downloader: make model updates more robust.
* Separate baseline and native benchmarks in pytest.
* Fix native benchmarks
* Fix torchvision model utils.
---
 .github/workflows/test-models.yml |  8 +--
 conftest.py                       | 26 ++++++---
 pytest.ini                        |  2 +-
 shark/parser.py                   |  4 +-
 shark/shark_benchmark_runner.py   | 37 +++++++++----
 shark/shark_downloader.py         | 87 +++++++++++++++++++++++++------
 shark/shark_importer.py           | 17 ++++--
 tank/all_models.csv               |  6 +--
 tank/generate_sharktank.py        |  7 ++-
 tank/model_utils.py               | 80 ++++++++++++++--------------
 tank/model_utils_tf.py            |  3 ++
 tank/test_models.py               | 82 +++++++++++++++++++----------
 12 files changed, 244 insertions(+), 115 deletions(-)

diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml
index 8dd2e8fa..c458d656 100644
--- a/.github/workflows/test-models.yml
+++ b/.github/workflows/test-models.yml
@@ -112,7 +112,7 @@ jobs:
           cd $GITHUB_WORKSPACE
           PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
+          pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
           gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
           gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -122,7 +122,7 @@ jobs:
           cd $GITHUB_WORKSPACE
           PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
+          pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
           gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
           gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
           # Disabled due to black image bug
@@ -145,14 +145,14 @@ jobs:
           cd $GITHUB_WORKSPACE
           PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
+          pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
           python build_tools/stable_diffusion_testing.py --device=vulkan

       - name: Validate Vulkan Models (Windows)
         if: matrix.suite == 'vulkan' && matrix.os == '7950x'
         run: |
           ./setup_venv.ps1
-          pytest -k vulkan -s
+          pytest -k vulkan -s --ci

       - name: Validate Stable Diffusion Models (Windows)
         if: matrix.suite == 'vulkan' && matrix.os == '7950x'
diff --git a/conftest.py b/conftest.py
index 89cf39aa..95442d76 100644
--- a/conftest.py
+++ b/conftest.py
@@ -2,9 +2,11 @@ def pytest_addoption(parser):
     # Attaches SHARK command-line arguments to the pytest machinery.
     parser.addoption(
         "--benchmark",
-        action="store_true",
-        default="False",
-        help="Pass option to benchmark and write results.csv",
+        action="store",
+        type=str,
+        default=None,
+        choices=("baseline", "native", "all"),
+        help="Benchmarks specified engine(s) and writes bench_results.csv.",
     )
     parser.addoption(
         "--onnx_bench",
@@ -40,7 +42,13 @@ def pytest_addoption(parser):
         "--update_tank",
         action="store_true",
         default="False",
-        help="Update local shark tank with latest artifacts.",
+        help="Update local shark tank with latest artifacts if the model artifact hash does not match upstream.",
+    )
+    parser.addoption(
+        "--force_update_tank",
+        action="store_true",
+        default="False",
+        help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
     )
     parser.addoption(
         "--ci_sha",
@@ -51,15 +59,21 @@ def pytest_addoption(parser):
     parser.addoption(
         "--local_tank_cache",
         action="store",
-        default="",
+        default=None,
        help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
     )
     parser.addoption(
         "--tank_url",
         type=str,
-        default="gs://shark_tank/latest",
+        default="gs://shark_tank/nightly",
         help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
     )
+    parser.addoption(
+        "--tank_prefix",
+        type=str,
+        default="nightly",
+        help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is 'nightly'.",
+    )
     parser.addoption(
         "--benchmark_dispatches",
         default=None,
diff --git a/pytest.ini b/pytest.ini
index 3634f600..11f57888 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,3 +1,3 @@
 [pytest]
-addopts = --verbose -p no:warnings
+addopts = --verbose -s -p no:warnings
 norecursedirs = inference tank/tflite examples benchmarks shark
diff --git a/shark/parser.py b/shark/parser.py
index d08ebdb9..47cc1a86 100644
--- a/shark/parser.py
+++ b/shark/parser.py
@@ -14,8 +14,10 @@
 import argparse
 import os
+import subprocess

 parser = argparse.ArgumentParser(description="SHARK runner.")
+
 parser.add_argument(
     "--device",
     type=str,
@@ -54,7 +56,7 @@ parser.add_argument(
 )
 parser.add_argument(
     "--shark_prefix",
-    default="latest",
+    default=None,
     help="gs://shark_tank//model_directories",
 )
 parser.add_argument(
diff --git a/shark/shark_benchmark_runner.py b/shark/shark_benchmark_runner.py
index c8f63b8e..81ba410d 100644
--- a/shark/shark_benchmark_runner.py
+++ b/shark/shark_benchmark_runner.py
@@ -118,10 +118,7 @@ class SharkBenchmarkRunner(SharkRunner):
         if self.device == "cuda":
             torch.set_default_tensor_type(torch.cuda.FloatTensor)
             if self.enable_tf32:
-                print(
-                    "Currently disabled TensorFloat32 calculations in pytorch benchmarks."
-                )
-                # torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
         else:
             torch.set_default_tensor_type(torch.FloatTensor)
         torch_device = torch.device(
@@ -133,12 +130,12 @@ class SharkBenchmarkRunner(SharkRunner):
         input.to(torch_device)

         # TODO: re-enable as soon as pytorch CUDA context issues are resolved
-        # try:
-        #     frontend_model = torch.compile(
-        #         frontend_model, mode="max-autotune", backend="inductor"
-        #     )
-        # except RuntimeError:
-        #     frontend_model = HFmodel.model
+        try:
+            frontend_model = torch.compile(
+                frontend_model, mode="max-autotune", backend="inductor"
+            )
+        except RuntimeError:
+            frontend_model = HFmodel.model

         for i in range(shark_args.num_warmup_iterations):
             frontend_model.forward(input)
@@ -152,12 +149,18 @@ class SharkBenchmarkRunner(SharkRunner):
         if self.device == "cuda":
             stats = torch.cuda.memory_stats()
             device_peak_b = stats["allocated_bytes.all.peak"]
+            frontend_model.to(torch.device("cpu"))
+            input.to(torch.device("cpu"))
+            torch.cuda.empty_cache()
         else:
             device_peak_b = None
         print(
             f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
         )
+        if self.device == "cuda":
+            # Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
+            torch_device = torch.device("cpu")
         return [
             f"{shark_args.num_iterations/(end-begin)}",
             f"{((end-begin)/shark_args.num_iterations)*1000}",
@@ -166,6 +169,9 @@ class SharkBenchmarkRunner(SharkRunner):
         ]

     def benchmark_tf(self, modelname):
+        import os
+
+        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
         import tensorflow as tf

         visible_default = tf.config.list_physical_devices("GPU")
@@ -354,9 +360,11 @@ for currently supported models. Exiting benchmark ONNX."
         device_str,
         frontend,
         import_args,
+        mode="native",
     ):
         self.setup_cl(inputs)
         self.import_args = import_args
+        self.mode = mode
         field_names = [
             "model",
             "batch_size",
@@ -379,7 +387,13 @@ for currently supported models. Exiting benchmark ONNX."
             "measured_device_memory_mb",
         ]
         # "frontend" must be the first element.
-        engines = ["frontend", "shark_python", "shark_iree_c"]
+        if self.mode == "native":
+            engines = ["shark_python", "shark_iree_c"]
+        if self.mode == "baseline":
+            engines = ["frontend"]
+        if self.mode == "all":
+            engines = ["frontend", "shark_python", "shark_iree_c"]
+
         if shark_args.onnx_bench == True:
             engines.append("onnxruntime")
@@ -407,6 +421,7 @@
         for e in engines:
             engine_result = {}
+            self.frontend_result = None
             if e == "frontend":
                 engine_result["engine"] = frontend
                 if check_requirements(frontend):
diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py
index 0164b39d..5c99f021 100644
--- a/shark/shark_downloader.py
+++ b/shark/shark_downloader.py
@@ -127,16 +127,73 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
         and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
         and os.path.isfile(os.path.join(model_dir, "hash.npy"))
     ):
-        print(f"""Using cached models from {WORKDIR}...""")
+        print(
+            f"""Model artifacts for {model_name} found at {WORKDIR}..."""
+        )
         return True
     return False
+
+
+def _internet_connected():
+    import requests as req
+
+    try:
+        req.get("http://1.1.1.1")
+        return True
+    except:
+        return False
+
+
+def get_git_revision_short_hash() -> str:
+    import subprocess
+
+    if shark_args.shark_prefix is not None:
+        prefix_kw = shark_args.shark_prefix
+    else:
+        prefix_kw = (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode("ascii")
+            .strip()
+        )
+    return prefix_kw
+
+
+def get_sharktank_prefix():
+    tank_prefix = ""
+    if not _internet_connected():
+        print(
+            "No internet connection. Using the model already present in the tank."
+        )
+        tank_prefix = "none"
+    else:
+        desired_prefix = get_git_revision_short_hash()
+        storage_client_a = storage.Client.create_anonymous_client()
+        base_bucket_name = "shark_tank"
+        base_bucket = storage_client_a.bucket(base_bucket_name)
+        dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
+        for blob in dir_blobs:
+            dir_blob_name = blob.name.split("/")
+            if desired_prefix in dir_blob_name[0]:
+                tank_prefix = dir_blob_name[0]
+                break
+            else:
+                continue
+        if tank_prefix == "":
+            print(
+                f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
+            )
+            tank_prefix = "nightly"
+    return tank_prefix
+
+
+shark_args.shark_prefix = get_sharktank_prefix()
+
+
 # Downloads the torch model from gs://shark_tank dir.
 def download_model(
     model_name,
     dynamic=False,
-    tank_url="gs://shark_tank/latest",
+    tank_url=None,
     frontend=None,
     tuned=None,
     import_args={"batch_size": "1"},
@@ -155,15 +212,19 @@ def download_model(
     else:
         model_dir_name = model_name + "_" + frontend
     model_dir = os.path.join(WORKDIR, model_dir_name)
-    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
+    if not tank_url:
+        tank_url = "gs://shark_tank/" + shark_args.shark_prefix
+
+    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
     if not check_dir_exists(
         model_dir_name, frontend=frontend, dynamic=dyn_str
     ):
         print(
-            f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
+            f"Downloading artifacts for model {model_name} from: {full_gs_url}"
         )
         download_public_file(full_gs_url, model_dir)
+
     elif shark_args.force_update_tank == True:
         print(
             f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
@@ -189,6 +250,7 @@ def download_model(
             np.load(os.path.join(model_dir, "upstream_hash.npy"))
         )
     except FileNotFoundError:
+        print(f"Model artifact hash not found at {model_dir}.")
         upstream_hash = None
     if local_hash != upstream_hash and shark_args.update_tank == True:
         print(f"Updating artifacts for model {model_name}...")
@@ -196,14 +258,17 @@
     elif local_hash != upstream_hash:
         print(
-            "Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
+            "Hash does not match upstream in the gs://shark_tank bucket. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
+        )
+    else:
+        print(
+            "Local and upstream hashes match. Using cached model artifacts."
         )
     model_dir = os.path.join(WORKDIR, model_dir_name)
     tuned_str = "" if tuned is None else "_" + tuned
     suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
     filename = os.path.join(model_dir, model_name + suffix)
-
     if not os.path.exists(filename):
         from tank.generate_sharktank import gen_shark_files
@@ -222,13 +287,3 @@ def download_model(
     inputs_tuple = tuple([inputs[key] for key in inputs])
     golden_out_tuple = tuple([golden_out[key] for key in golden_out])
     return mlir_file, function_name, inputs_tuple, golden_out_tuple
-
-
-def _internet_connected():
-    import requests as req
-
-    try:
-        req.get("http://1.1.1.1")
-        return True
-    except:
-        return False
diff --git a/shark/shark_importer.py b/shark/shark_importer.py
index 89dd4042..f9c74bb5 100644
--- a/shark/shark_importer.py
+++ b/shark/shark_importer.py
@@ -9,8 +9,8 @@ import hashlib

 def create_hash(file_name):
     with open(file_name, "rb") as f:
-        file_hash = hashlib.blake2b()
-        while chunk := f.read(2**20):
+        file_hash = hashlib.blake2b(digest_size=64)
+        while chunk := f.read(2**10):
             file_hash.update(chunk)
     return file_hash.hexdigest()

@@ -165,8 +165,18 @@ class SharkImporter:
         if self.frontend == "torch":
             with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
                 mlir_file.write(mlir_data)
-            mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
-            np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
+            hash_gen_attempts = 2
+            for i in range(hash_gen_attempts):
+                try:
+                    mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
+                    break
+                except FileNotFoundError as err:
+                    if i < hash_gen_attempts - 1:
+                        continue
+                    else:
+                        raise err
+
+            np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
             return

     def import_debug(
diff --git a/tank/all_models.csv b/tank/all_models.csv
index 5e479bd4..ab8a04e7 100644
--- a/tank/all_models.csv
+++ b/tank/all_models.csv
@@ -36,9 +36,9 @@ wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,Fal
 efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
 mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
 efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
-efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
-efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
-efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
+efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
+efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
+efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
 gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
 t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
 t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
diff --git a/tank/generate_sharktank.py b/tank/generate_sharktank.py
index 29906008..0228e5e4 100644
--- a/tank/generate_sharktank.py
+++ b/tank/generate_sharktank.py
@@ -26,8 +26,8 @@ from apps.stable_diffusion.src.utils.stable_args import (

 def create_hash(file_name):
     with open(file_name, "rb") as f:
-        file_hash = hashlib.blake2b()
-        while chunk := f.read(2**20):
+        file_hash = hashlib.blake2b(digest_size=64)
+        while chunk := f.read(2**10):
             file_hash.update(chunk)
     return file_hash.hexdigest()

@@ -141,6 +141,9 @@ def save_tf_model(tf_model_list, local_tank_cache, import_args):
         get_TFhf_model,
         get_tfhf_seq2seq_model,
     )
+    import os
+
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
     import tensorflow as tf

     visible_default = tf.config.list_physical_devices("GPU")
diff --git a/tank/model_utils.py b/tank/model_utils.py
index 0010e4fa..b9567803 100644
--- a/tank/model_utils.py
+++ b/tank/model_utils.py
@@ -195,47 +195,44 @@ def get_vision_model(torch_model, import_args):
     import torchvision.models as models

     default_image_size = (224, 224)
+    modelname = torch_model
+    if modelname == "alexnet":
+        torch_model = models.alexnet(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "resnet18":
+        torch_model = models.resnet18(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "resnet50":
+        torch_model = models.resnet50(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "resnet50_fp16":
+        torch_model = models.resnet50(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "resnet101":
+        torch_model = models.resnet101(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "squeezenet1_0":
+        torch_model = models.squeezenet1_0(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "wide_resnet50_2":
+        torch_model = models.wide_resnet50_2(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "mobilenet_v3_small":
+        torch_model = models.mobilenet_v3_small(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "mnasnet1_0":
+        torch_model = models.mnasnet1_0(weights="DEFAULT")
+        input_image_size = default_image_size
+    if modelname == "efficientnet_b0":
+        torch_model = models.efficientnet_b0(weights="DEFAULT")
+        input_image_size = (224, 224)
+    if modelname == "efficientnet_b7":
+        torch_model = models.efficientnet_b7(weights="DEFAULT")
+        input_image_size = (600, 600)

-    vision_models_dict = {
-        "alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
-        "resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
-        "resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
-        "resnet50_fp16": (
-            models.resnet50(weights="DEFAULT"),
-            default_image_size,
-        ),
-        "resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
-        "squeezenet1_0": (
-            models.squeezenet1_0(weights="DEFAULT"),
-            default_image_size,
-        ),
-        "wide_resnet50_2": (
-            models.wide_resnet50_2(weights="DEFAULT"),
-            default_image_size,
-        ),
-        "mobilenet_v3_small": (
-            models.mobilenet_v3_small(weights="DEFAULT"),
-            default_image_size,
-        ),
-        "mnasnet1_0": (
-            models.mnasnet1_0(weights="DEFAULT"),
-            default_image_size,
-        ),
-        # EfficientNet input image size varies on the size of the model.
-        "efficientnet_b0": (
-            models.efficientnet_b0(weights="DEFAULT"),
-            (224, 224),
-        ),
-        "efficientnet_b7": (
-            models.efficientnet_b7(weights="DEFAULT"),
-            (600, 600),
-        ),
-    }
-    if isinstance(torch_model, str):
-        fp16_model = None
-        if "fp16" in torch_model:
-            fp16_model = True
-        torch_model, input_image_size = vision_models_dict[torch_model]
+    fp16_model = False
+    if "fp16" in modelname:
+        fp16_model = True
     model = VisionModule(torch_model)
     test_input = torch.randn(
         int(import_args["batch_size"]), 3, *input_image_size
diff --git a/tank/model_utils_tf.py b/tank/model_utils_tf.py
index b1b34a3c..78f3dda6 100644
--- a/tank/model_utils_tf.py
+++ b/tank/model_utils_tf.py
@@ -1,3 +1,6 @@
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 import tensorflow as tf
 import numpy as np

diff --git a/tank/test_models.py b/tank/test_models.py
index 02a0b5d9..ec16c46d 100644
--- a/tank/test_models.py
+++ b/tank/test_models.py
@@ -4,11 +4,8 @@ from shark.iree_utils._common import (
     get_supported_device_list,
 )
 from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
-from parameterized import parameterized
-from shark.shark_downloader import download_model
-from shark.shark_inference import SharkInference
 from shark.parser import shark_args
-from tank.generate_sharktank import NoImportException
+from parameterized import parameterized
 import iree.compiler as ireec
 import pytest
 import unittest
@@ -16,8 +13,8 @@ import numpy as np
 import csv
 import tempfile
 import os
+import sys
 import shutil
-import multiprocessing


 def load_csv_and_convert(filename, gen=False):
@@ -140,12 +137,12 @@ class SharkModuleTester:
         self.config = config

     def create_and_check_module(self, dynamic, device):
-        import_config = {
-            "batch_size": self.batch_size,
-        }
+        shark_args.update_tank = self.update_tank
+        shark_args.force_update_tank = self.force_update_tank
+        shark_args.shark_prefix = self.shark_tank_prefix
         shark_args.local_tank_cache = self.local_tank_cache
-        shark_args.force_update_tank = self.update_tank
         shark_args.dispatch_benchmarks = self.benchmark_dispatches
+
         if self.benchmark_dispatches is not None:
             _m = self.config["model_name"].split("/")
             _m.extend([self.config["framework"], str(dynamic), device])
@@ -169,12 +166,19 @@ class SharkModuleTester:
         if "winograd" in self.config["flags"]:
             shark_args.use_winograd = True

+        import_config = {
+            "batch_size": self.batch_size,
+        }
+
+        from shark.shark_downloader import download_model
+        from shark.shark_inference import SharkInference
+        from tank.generate_sharktank import NoImportException
+
         dl_gen_attempts = 2
         for i in range(dl_gen_attempts):
             try:
                 model, func_name, inputs, golden_out = download_model(
                     self.config["model_name"],
-                    tank_url=self.tank_url,
                     frontend=self.config["framework"],
                     import_args=import_config,
                 )
@@ -190,11 +194,12 @@ class SharkModuleTester:
                     "Generating OTF may require exiting the subprocess for files to be available."
                 )
             break
+        is_bench = self.benchmark is not None
         shark_module = SharkInference(
             model,
             device=device,
             mlir_dialect=self.config["dialect"],
-            is_benchmark=self.benchmark,
+            is_benchmark=is_bench,
         )

         try:
@@ -208,6 +213,10 @@ class SharkModuleTester:
         result = shark_module(func_name, inputs)
         golden_out, result = self.postprocess_outputs(golden_out, result)
+        if self.tf32 == "true":
+            print("Validating with relaxed tolerances.")
+            atol = 1e-02
+            rtol = 1e-03
         try:
             np.testing.assert_allclose(
                 golden_out,
@@ -220,19 +229,25 @@ class SharkModuleTester:
             self.save_reproducers()
             if self.ci == True:
                 self.upload_repro()
-            if self.benchmark == True:
-                self.benchmark_module(shark_module, inputs, dynamic, device)
+            if self.benchmark is not None:
+                self.benchmark_module(
+                    shark_module, inputs, dynamic, device, mode=self.benchmark
+                )
             print(msg)
             pytest.xfail(
                 reason=f"Numerics Mismatch: Use -s flag to print stderr during pytests."
             )

-        if self.benchmark == True:
-            self.benchmark_module(shark_module, inputs, dynamic, device)
+        if self.benchmark is not None:
+            self.benchmark_module(
+                shark_module, inputs, dynamic, device, mode=self.benchmark
+            )

         if self.save_repro == True:
             self.save_reproducers()

-    def benchmark_module(self, shark_module, inputs, dynamic, device):
+    def benchmark_module(
+        self, shark_module, inputs, dynamic, device, mode="native"
+    ):
         model_config = {
             "batch_size": self.batch_size,
         }
@@ -248,6 +263,7 @@ class SharkModuleTester:
             device,
             self.config["framework"],
             import_args=model_config,
+            mode=mode,
         )

     def save_reproducers(self):
@@ -319,7 +335,12 @@ class SharkModuleTest(unittest.TestCase):
         self.module_tester.update_tank = self.pytestconfig.getoption(
             "update_tank"
         )
-        self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
+        self.module_tester.force_update_tank = self.pytestconfig.getoption(
+            "force_update_tank"
+        )
+        self.module_tester.shark_tank_prefix = self.pytestconfig.getoption(
+            "tank_prefix"
+        )
         self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
             "benchmark_dispatches"
         )
@@ -336,19 +357,26 @@ class SharkModuleTest(unittest.TestCase):
         if config["xfail_vkm"] == "True" and device in ["metal", "vulkan"]:
             pytest.xfail(reason=config["xfail_reason"])

-        if os.name == "nt" and "enabled_windows" not in config["xfail_other"]:
+        if (
+            self.pytestconfig.getoption("ci") == True
+            and os.name == "nt"
+            and "enabled_windows" not in config["xfail_other"]
+        ):
             pytest.xfail(reason="this model skipped on windows")

         # Special cases that need to be marked.
-        if "macos" in config["xfail_other"] and device in [
-            "metal",
-            "vulkan",
-        ]:
-            if get_vulkan_triple_flag() is not None:
-                if "m1-moltenvk-macos" in get_vulkan_triple_flag():
-                    pytest.xfail(
-                        reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
-                    )
+        if (
+            "macos" in config["xfail_other"]
+            and device
+            in [
+                "metal",
+                "vulkan",
+            ]
+            and sys.platform == "darwin"
+        ):
+            pytest.skip(
+                reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
+            )

         if (
             config["model_name"]
             in [