Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 21:38:04 -05:00)
Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI
* downloader: make model updates more robust.
* Separate baseline and native benchmarks in pytest.
* Fix native benchmarks
* Fix torchvision model utils.
.github/workflows/test-models.yml (vendored): 8 changed lines
@@ -112,7 +112,7 @@ jobs:
 cd $GITHUB_WORKSPACE
 PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
 source shark.venv/bin/activate
-pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
+pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
 gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
 gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -122,7 +122,7 @@ jobs:
 cd $GITHUB_WORKSPACE
 PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
 source shark.venv/bin/activate
-pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
+pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
 gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
 gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
 # Disabled due to black image bug
@@ -145,14 +145,14 @@ jobs:
 cd $GITHUB_WORKSPACE
 PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
 source shark.venv/bin/activate
-pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
+pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
 python build_tools/stable_diffusion_testing.py --device=vulkan

 - name: Validate Vulkan Models (Windows)
 if: matrix.suite == 'vulkan' && matrix.os == '7950x'
 run: |
 ./setup_venv.ps1
-pytest -k vulkan -s
+pytest -k vulkan -s --ci

 - name: Validate Stable Diffusion Models (Windows)
 if: matrix.suite == 'vulkan' && matrix.os == '7950x'
conftest.py: 26 changed lines
@@ -2,9 +2,11 @@ def pytest_addoption(parser):
 # Attaches SHARK command-line arguments to the pytest machinery.
 parser.addoption(
 "--benchmark",
-action="store_true",
-default="False",
-help="Pass option to benchmark and write results.csv",
+action="store",
+type=str,
+default=None,
+choices=("baseline", "native", "all"),
+help="Benchmarks specified engine(s) and writes bench_results.csv.",
 )
 parser.addoption(
 "--onnx_bench",
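With this change --benchmark is no longer a boolean switch: it stores one of "baseline", "native", or "all", and is None when the flag is not passed. A minimal sketch of reading the option through pytest's config object; the fixture and test names here are illustrative, not from the repository:

    import pytest


    @pytest.fixture
    def benchmark_mode(pytestconfig):
        # None when --benchmark was not passed, otherwise "baseline", "native", or "all".
        return pytestconfig.getoption("benchmark")


    def test_mode_is_wired(benchmark_mode):
        assert benchmark_mode in (None, "baseline", "native", "all")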
@@ -40,7 +42,13 @@ def pytest_addoption(parser):
 "--update_tank",
 action="store_true",
 default="False",
-help="Update local shark tank with latest artifacts.",
+help="Update local shark tank with latest artifacts if model artifact hash mismatched.",
 )
+parser.addoption(
+"--force_update_tank",
+action="store_true",
+default="False",
+help="Force-update local shark tank with artifacts from specified shark_tank URL (defaults to nightly).",
+)
 parser.addoption(
 "--ci_sha",
@@ -51,15 +59,21 @@ def pytest_addoption(parser):
 parser.addoption(
 "--local_tank_cache",
 action="store",
-default="",
+default=None,
 help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
 )
 parser.addoption(
 "--tank_url",
 type=str,
-default="gs://shark_tank/latest",
+default="gs://shark_tank/nightly",
 help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
 )
+parser.addoption(
+"--tank_prefix",
+type=str,
+default="nightly",
+help="Prefix to gs://shark_tank/ model directories from which to download SHARK tank artifacts. Default is 'latest'.",
+)
 parser.addoption(
 "--benchmark_dispatches",
 default=None,
@@ -1,3 +1,3 @@
 [pytest]
-addopts = --verbose -p no:warnings
+addopts = --verbose -s -p no:warnings
 norecursedirs = inference tank/tflite examples benchmarks shark
@@ -14,8 +14,10 @@
 import argparse
+import os
+import subprocess

 parser = argparse.ArgumentParser(description="SHARK runner.")

 parser.add_argument(
 "--device",
 type=str,
@@ -54,7 +56,7 @@ parser.add_argument(
 )
 parser.add_argument(
 "--shark_prefix",
-default="latest",
+default=None,
 help="gs://shark_tank/<this_flag>/model_directories",
 )
 parser.add_argument(
@@ -118,10 +118,7 @@ class SharkBenchmarkRunner(SharkRunner):
 if self.device == "cuda":
 torch.set_default_tensor_type(torch.cuda.FloatTensor)
 if self.enable_tf32:
-print(
-"Currently disabled TensorFloat32 calculations in pytorch benchmarks."
-)
-# torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
 else:
 torch.set_default_tensor_type(torch.FloatTensor)
 torch_device = torch.device(
@@ -133,12 +130,12 @@ class SharkBenchmarkRunner(SharkRunner):
 input.to(torch_device)

 # TODO: re-enable as soon as pytorch CUDA context issues are resolved
-# try:
-# frontend_model = torch.compile(
-# frontend_model, mode="max-autotune", backend="inductor"
-# )
-# except RuntimeError:
-# frontend_model = HFmodel.model
+try:
+frontend_model = torch.compile(
+frontend_model, mode="max-autotune", backend="inductor"
+)
+except RuntimeError:
+frontend_model = HFmodel.model

 for i in range(shark_args.num_warmup_iterations):
 frontend_model.forward(input)
@@ -152,12 +149,18 @@ class SharkBenchmarkRunner(SharkRunner):
 if self.device == "cuda":
 stats = torch.cuda.memory_stats()
 device_peak_b = stats["allocated_bytes.all.peak"]
+frontend_model.to(torch.device("cpu"))
+input.to(torch.device("cpu"))
+torch.cuda.empty_cache()
 else:
 device_peak_b = None

 print(
 f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
 )
+if self.device == "cuda":
+# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
+torch_device = torch.device("cpu")
 return [
 f"{shark_args.num_iterations/(end-begin)}",
 f"{((end-begin)/shark_args.num_iterations)*1000}",
@@ -166,6 +169,9 @@ class SharkBenchmarkRunner(SharkRunner):
 ]

 def benchmark_tf(self, modelname):
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 import tensorflow as tf

 visible_default = tf.config.list_physical_devices("GPU")
@@ -354,9 +360,11 @@ for currently supported models. Exiting benchmark ONNX."
 device_str,
 frontend,
 import_args,
+mode="native",
 ):
 self.setup_cl(inputs)
 self.import_args = import_args
+self.mode = mode
 field_names = [
 "model",
 "batch_size",
@@ -379,7 +387,13 @@ for currently supported models. Exiting benchmark ONNX."
 "measured_device_memory_mb",
 ]
-# "frontend" must be the first element.
-engines = ["frontend", "shark_python", "shark_iree_c"]
+if self.mode == "native":
+engines = ["shark_python", "shark_iree_c"]
+if self.mode == "baseline":
+engines = ["frontend"]
+if self.mode == "all":
+engines = ["frontend", "shark_python", "shark_iree_c"]

 if shark_args.onnx_bench == True:
 engines.append("onnxruntime")
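The CSV benchmarker now derives its engine list from the requested mode: "baseline" keeps only the framework ("frontend") run, "native" keeps only the SHARK/IREE runs, and "all" keeps both, with onnxruntime appended when --onnx_bench is set. A small standalone sketch of that selection, using a made-up helper name:

    # Illustrative helper mirroring the mode -> engine-list selection above.
    def select_engines(mode, onnx_bench=False):
        engines_by_mode = {
            "baseline": ["frontend"],                    # framework-native run only
            "native": ["shark_python", "shark_iree_c"],  # SHARK/IREE runs only
            "all": ["frontend", "shark_python", "shark_iree_c"],
        }
        engines = list(engines_by_mode.get(mode, engines_by_mode["native"]))
        if onnx_bench:
            engines.append("onnxruntime")
        return engines


    print(select_engines("baseline"))                # ['frontend']
    print(select_engines("all", onnx_bench=True))    # adds 'onnxruntime'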
@@ -407,6 +421,7 @@ for currently supported models. Exiting benchmark ONNX."

 for e in engines:
 engine_result = {}
+self.frontend_result = None
 if e == "frontend":
 engine_result["engine"] = frontend
 if check_requirements(frontend):
@@ -127,16 +127,73 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
 and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
 and os.path.isfile(os.path.join(model_dir, "hash.npy"))
 ):
-print(f"""Using cached models from {WORKDIR}...""")
+print(
+f"""Model artifacts for {model_name} found at {WORKDIR}..."""
+)
 return True
 return False


+def _internet_connected():
+import requests as req
+
+try:
+req.get("http://1.1.1.1")
+return True
+except:
+return False
+
+
+def get_git_revision_short_hash() -> str:
+import subprocess
+
+if shark_args.shark_prefix is not None:
+prefix_kw = shark_args.shark_prefix
+else:
+prefix_kw = (
+subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+.decode("ascii")
+.strip()
+)
+return prefix_kw
+
+
+def get_sharktank_prefix():
+tank_prefix = ""
+if not _internet_connected():
+print(
+"No internet connection. Using the model already present in the tank."
+)
+tank_prefix = "none"
+else:
+desired_prefix = get_git_revision_short_hash()
+storage_client_a = storage.Client.create_anonymous_client()
+base_bucket_name = "shark_tank"
+base_bucket = storage_client_a.bucket(base_bucket_name)
+dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}")
+for blob in dir_blobs:
+dir_blob_name = blob.name.split("/")
+if desired_prefix in dir_blob_name[0]:
+tank_prefix = dir_blob_name[0]
+break
+else:
+continue
+if tank_prefix == "":
+print(
+f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
+)
+tank_prefix = "nightly"
+return tank_prefix
+
+
+shark_args.shark_prefix = get_sharktank_prefix()
+
+
 # Downloads the torch model from gs://shark_tank dir.
 def download_model(
 model_name,
 dynamic=False,
-tank_url="gs://shark_tank/latest",
+tank_url=None,
 frontend=None,
 tuned=None,
 import_args={"batch_size": "1"},
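Taken together with the conftest changes, downloads now resolve their source lazily: an explicit --tank_url wins, otherwise the prefix comes from --tank_prefix / shark_prefix, and failing that the local git short hash is matched against gs://shark_tank/ directories with a fallback to "nightly". A condensed, illustrative sketch of that resolution order; the function and parameter names are assumptions, not the repository's API:

    # Condensed, illustrative version of the prefix/tank_url resolution above.
    import subprocess


    def resolve_tank_url(cli_prefix=None, tank_url=None, bucket_prefixes=()):
        if tank_url:                  # explicit --tank_url always wins
            return tank_url.rstrip("/")
        if cli_prefix:                # explicit --tank_prefix / --shark_prefix
            prefix = cli_prefix
        else:                         # else try to match the local git revision
            short_sha = (
                subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
                .decode("ascii")
                .strip()
            )
            prefix = next(
                (p for p in bucket_prefixes if p.startswith(short_sha)), "nightly"
            )
        return f"gs://shark_tank/{prefix}"


    print(resolve_tank_url(cli_prefix="nightly"))   # gs://shark_tank/nightly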
@@ -155,15 +212,19 @@ def download_model(
 else:
 model_dir_name = model_name + "_" + frontend
 model_dir = os.path.join(WORKDIR, model_dir_name)
-full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name

+if not tank_url:
+tank_url = "gs://shark_tank/" + shark_args.shark_prefix
+
+full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
 if not check_dir_exists(
 model_dir_name, frontend=frontend, dynamic=dyn_str
 ):
 print(
-f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
+f"Downloading artifacts for model {model_name} from: {full_gs_url}"
 )
 download_public_file(full_gs_url, model_dir)

 elif shark_args.force_update_tank == True:
 print(
 f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
@@ -189,6 +250,7 @@ def download_model(
 np.load(os.path.join(model_dir, "upstream_hash.npy"))
 )
 except FileNotFoundError:
+print(f"Model artifact hash not found at {model_dir}.")
 upstream_hash = None
 if local_hash != upstream_hash and shark_args.update_tank == True:
 print(f"Updating artifacts for model {model_name}...")
@@ -196,14 +258,17 @@ def download_model(

 elif local_hash != upstream_hash:
 print(
-"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
+"Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
 )
+else:
+print(
+"Local and upstream hashes match. Using cached model artifacts."
+)

 model_dir = os.path.join(WORKDIR, model_dir_name)
 tuned_str = "" if tuned is None else "_" + tuned
 suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
 filename = os.path.join(model_dir, model_name + suffix)

 if not os.path.exists(filename):
 from tank.generate_sharktank import gen_shark_files
@@ -222,13 +287,3 @@ def download_model(
 inputs_tuple = tuple([inputs[key] for key in inputs])
 golden_out_tuple = tuple([golden_out[key] for key in golden_out])
 return mlir_file, function_name, inputs_tuple, golden_out_tuple
-
-
-def _internet_connected():
-import requests as req
-
-try:
-req.get("http://1.1.1.1")
-return True
-except:
-return False
@@ -9,8 +9,8 @@ import hashlib

 def create_hash(file_name):
 with open(file_name, "rb") as f:
-file_hash = hashlib.blake2b()
-while chunk := f.read(2**20):
+file_hash = hashlib.blake2b(digest_size=64)
+while chunk := f.read(2**10):
 file_hash.update(chunk)

 return file_hash.hexdigest()
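create_hash streams the file through BLAKE2b in small chunks, so large .mlir artifacts never have to be read into memory at once; incremental update() calls produce the same digest as hashing the whole byte string in one shot. A small self-contained check of that property (the temp file and chunk size are arbitrary):

    # Sketch: chunked BLAKE2b hashing matches one-shot hashing of the same bytes.
    import hashlib
    import tempfile

    data = b"shark_tank artifact bytes" * 1000

    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(data)
        path = tmp.name

    chunked = hashlib.blake2b(digest_size=64)
    with open(path, "rb") as f:
        while chunk := f.read(2**10):   # 1 KiB chunks, as in create_hash above
            chunked.update(chunk)

    one_shot = hashlib.blake2b(data, digest_size=64)
    assert chunked.hexdigest() == one_shot.hexdigest()
    print(chunked.hexdigest()[:16], "OK")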
@@ -165,8 +165,17 @@ class SharkImporter:
 if self.frontend == "torch":
 with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
 mlir_file.write(mlir_data)
-mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
-np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
+hash_gen_attempts = 2
+for i in range(hash_gen_attempts):
+try:
+mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
+except FileNotFoundError as err:
+if i < hash_gen_attempts:
+continue
+else:
+raise err
+
+np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
 return

 def import_debug(
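The hash is now computed inside a short retry loop, presumably to tolerate a transient FileNotFoundError immediately after the artifact is written. A generic bounded-retry helper in the same spirit, purely illustrative and not part of the repository:

    # Illustrative bounded-retry helper for a flaky, exception-raising callable.
    def retry(fn, attempts=2, exc=FileNotFoundError):
        last_err = None
        for _ in range(attempts):
            try:
                return fn()
            except exc as err:   # retry only the expected failure mode
                last_err = err
        raise last_err


    # Usage example with a callable that succeeds on the first try.
    value = retry(lambda: int("42"), attempts=2, exc=ValueError)
    print(value)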
@@ -36,9 +36,9 @@ wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,Fal
 efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
 mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
 efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
-efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
-efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
-efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
+efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
+efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
+efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
 gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
 t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
 t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
@@ -26,8 +26,8 @@ from apps.stable_diffusion.src.utils.stable_args import (

 def create_hash(file_name):
 with open(file_name, "rb") as f:
-file_hash = hashlib.blake2b()
-while chunk := f.read(2**20):
+file_hash = hashlib.blake2b(digest_size=64)
+while chunk := f.read(2**10):
 file_hash.update(chunk)

 return file_hash.hexdigest()
@@ -141,6 +141,9 @@ def save_tf_model(tf_model_list, local_tank_cache, import_args):
 get_TFhf_model,
 get_tfhf_seq2seq_model,
 )
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 import tensorflow as tf

 visible_default = tf.config.list_physical_devices("GPU")
@@ -195,47 +195,47 @@ def get_vision_model(torch_model, import_args):
 import torchvision.models as models

 default_image_size = (224, 224)
+modelname = torch_model
+if modelname == "alexnet":
+torch_model = models.alexnet(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "resnet18":
+torch_model = models.resnet18(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "resnet50":
+torch_model = models.resnet50(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "resnet50_fp16":
+torch_model = models.resnet50(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "resnet50_fp16":
+torch_model = models.resnet50(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "resnet101":
+torch_model = models.resnet101(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "squeezenet1_0":
+torch_model = models.squeezenet1_0(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "wide_resnet50_2":
+torch_model = models.wide_resnet50_2(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "mobilenet_v3_small":
+torch_model = models.mobilenet_v3_small(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "mnasnet1_0":
+torch_model = models.mnasnet1_0(weights="DEFAULT")
+input_image_size = default_image_size
+if modelname == "efficientnet_b0":
+torch_model = models.efficientnet_b0(weights="DEFAULT")
+input_image_size = (224, 224)
+if modelname == "efficientnet_b7":
+torch_model = models.efficientnet_b7(weights="DEFAULT")
+input_image_size = (600, 600)

-vision_models_dict = {
-"alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
-"resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
-"resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
-"resnet50_fp16": (
-models.resnet50(weights="DEFAULT"),
-default_image_size,
-),
-"resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
-"squeezenet1_0": (
-models.squeezenet1_0(weights="DEFAULT"),
-default_image_size,
-),
-"wide_resnet50_2": (
-models.wide_resnet50_2(weights="DEFAULT"),
-default_image_size,
-),
-"mobilenet_v3_small": (
-models.mobilenet_v3_small(weights="DEFAULT"),
-default_image_size,
-),
-"mnasnet1_0": (
-models.mnasnet1_0(weights="DEFAULT"),
-default_image_size,
-),
-# EfficientNet input image size varies on the size of the model.
-"efficientnet_b0": (
-models.efficientnet_b0(weights="DEFAULT"),
-(224, 224),
-),
-"efficientnet_b7": (
-models.efficientnet_b7(weights="DEFAULT"),
-(600, 600),
-),
-}
-if isinstance(torch_model, str):
-fp16_model = None
-if "fp16" in torch_model:
-fp16_model = True
-torch_model, input_image_size = vision_models_dict[torch_model]
+fp16_model = False
+if "fp16" in modelname:
+fp16_model = True
 model = VisionModule(torch_model)
 test_input = torch.randn(
 int(import_args["batch_size"]), 3, *input_image_size
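The rewritten lookup instantiates only the torchvision model that was requested; the old vision_models_dict built (and pulled default weights for) every entry on each call just to index one of them. If a table-driven style were still wanted, a dict of constructor callables would keep the lookup but stay lazy; a sketch of that alternative, which is not what the repository does:

    # Alternative sketch: table-driven but lazy, mapping names to constructor callables.
    import torchvision.models as models

    DEFAULT_SIZE = (224, 224)

    VISION_MODEL_FACTORIES = {
        "alexnet": (lambda: models.alexnet(weights="DEFAULT"), DEFAULT_SIZE),
        "resnet50": (lambda: models.resnet50(weights="DEFAULT"), DEFAULT_SIZE),
        "efficientnet_b7": (lambda: models.efficientnet_b7(weights="DEFAULT"), (600, 600)),
    }


    def build_vision_model(name):
        factory, input_size = VISION_MODEL_FACTORIES[name]
        return factory(), input_size   # only the requested model is instantiated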
@@ -1,3 +1,6 @@
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 import tensorflow as tf
 import numpy as np

@@ -4,11 +4,8 @@ from shark.iree_utils._common import (
 get_supported_device_list,
 )
 from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
-from parameterized import parameterized
-from shark.shark_downloader import download_model
-from shark.shark_inference import SharkInference
 from shark.parser import shark_args
-from tank.generate_sharktank import NoImportException
+from parameterized import parameterized
 import iree.compiler as ireec
 import pytest
 import unittest
@@ -16,8 +13,8 @@ import numpy as np
 import csv
 import tempfile
 import os
+import sys
 import shutil
-import multiprocessing


 def load_csv_and_convert(filename, gen=False):
@@ -140,12 +137,12 @@ class SharkModuleTester:
 self.config = config

 def create_and_check_module(self, dynamic, device):
-import_config = {
-"batch_size": self.batch_size,
-}
 shark_args.update_tank = self.update_tank
+shark_args.force_update_tank = self.force_update_tank
+shark_args.shark_prefix = self.shark_tank_prefix
 shark_args.local_tank_cache = self.local_tank_cache
-shark_args.force_update_tank = self.update_tank
 shark_args.dispatch_benchmarks = self.benchmark_dispatches

 if self.benchmark_dispatches is not None:
 _m = self.config["model_name"].split("/")
 _m.extend([self.config["framework"], str(dynamic), device])
@@ -169,12 +166,19 @@ class SharkModuleTester:
 if "winograd" in self.config["flags"]:
 shark_args.use_winograd = True

+import_config = {
+"batch_size": self.batch_size,
+}
+
+from shark.shark_downloader import download_model
+from shark.shark_inference import SharkInference
+from tank.generate_sharktank import NoImportException

 dl_gen_attempts = 2
 for i in range(dl_gen_attempts):
 try:
 model, func_name, inputs, golden_out = download_model(
 self.config["model_name"],
 tank_url=self.tank_url,
 frontend=self.config["framework"],
 import_args=import_config,
 )
@@ -190,11 +194,12 @@ class SharkModuleTester:
 "Generating OTF may require exiting the subprocess for files to be available."
 )
 break
+is_bench = True if self.benchmark is not None else False
 shark_module = SharkInference(
 model,
 device=device,
 mlir_dialect=self.config["dialect"],
-is_benchmark=self.benchmark,
+is_benchmark=is_bench,
 )

 try:
@@ -208,6 +213,10 @@ class SharkModuleTester:

 result = shark_module(func_name, inputs)
 golden_out, result = self.postprocess_outputs(golden_out, result)
+if self.tf32 == "true":
+print("Validating with relaxed tolerances.")
+atol = 1e-02
+rtol = 1e-03
 try:
 np.testing.assert_allclose(
 golden_out,
@@ -220,19 +229,25 @@ class SharkModuleTester:
 self.save_reproducers()
 if self.ci == True:
 self.upload_repro()
-if self.benchmark == True:
-self.benchmark_module(shark_module, inputs, dynamic, device)
+if self.benchmark is not None:
+self.benchmark_module(
+shark_module, inputs, dynamic, device, mode=self.benchmark
+)
 print(msg)
 pytest.xfail(
 reason=f"Numerics Mismatch: Use -s flag to print stderr during pytests."
 )
-if self.benchmark == True:
-self.benchmark_module(shark_module, inputs, dynamic, device)
+if self.benchmark is not None:
+self.benchmark_module(
+shark_module, inputs, dynamic, device, mode=self.benchmark
+)

 if self.save_repro == True:
 self.save_reproducers()

-def benchmark_module(self, shark_module, inputs, dynamic, device):
+def benchmark_module(
+self, shark_module, inputs, dynamic, device, mode="native"
+):
 model_config = {
 "batch_size": self.batch_size,
 }
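End to end, the value passed to --benchmark now travels unchanged: conftest stores it, the tester keeps it in self.benchmark, create_and_check_module forwards it as mode=, and benchmark_module hands it to the CSV benchmarker's engine selection. A compressed, illustrative sketch of that flow (the class below is a stand-in, not the repository's SharkModuleTester):

    # Compressed, illustrative flow: pytest option -> tester attribute -> mode kwarg.
    class ModuleTesterSketch:
        def __init__(self, benchmark_option):
            # None when --benchmark is absent, else "baseline" | "native" | "all".
            self.benchmark = benchmark_option

        def create_and_check_module(self):
            if self.benchmark is not None:
                self.benchmark_module(mode=self.benchmark)

        def benchmark_module(self, mode="native"):
            print(f"benchmarking with engine mode: {mode}")


    ModuleTesterSketch("baseline").create_and_check_module()   # engine mode: baseline
    ModuleTesterSketch(None).create_and_check_module()         # benchmarking skipped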
@@ -248,6 +263,7 @@ class SharkModuleTester:
 device,
 self.config["framework"],
 import_args=model_config,
+mode=mode,
 )

 def save_reproducers(self):
@@ -319,7 +335,12 @@ class SharkModuleTest(unittest.TestCase):
 self.module_tester.update_tank = self.pytestconfig.getoption(
 "update_tank"
 )
-self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
+self.module_tester.force_update_tank = self.pytestconfig.getoption(
+"force_update_tank"
+)
+self.module_tester.shark_tank_prefix = self.pytestconfig.getoption(
+"tank_prefix"
+)
 self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
 "benchmark_dispatches"
 )
@@ -336,19 +357,26 @@ class SharkModuleTest(unittest.TestCase):
 if config["xfail_vkm"] == "True" and device in ["metal", "vulkan"]:
 pytest.xfail(reason=config["xfail_reason"])

-if os.name == "nt" and "enabled_windows" not in config["xfail_other"]:
+if (
+self.pytestconfig.getoption("ci") == True
+and os.name == "nt"
+and "enabled_windows" not in config["xfail_other"]
+):
 pytest.xfail(reason="this model skipped on windows")

 # Special cases that need to be marked.
-if "macos" in config["xfail_other"] and device in [
-"metal",
-"vulkan",
-]:
-if get_vulkan_triple_flag() is not None:
-if "m1-moltenvk-macos" in get_vulkan_triple_flag():
-pytest.xfail(
-reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
-)
+if (
+"macos" in config["xfail_other"]
+and device
+in [
+"metal",
+"vulkan",
+]
+and sys.platform == "darwin"
+):
+pytest.skip(
+reason="conv-related issue on MacStudio, returns VK_ERROR_DEVICE_LOST."
+)
 if (
 config["model_name"]
 in [