mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Add pytest options to save reproducers. (#350)
* Add pytest options to save and/or upload reproducers. * pass shark_module to benchmark method.
This commit is contained in:
6
.github/workflows/test-models.yml
vendored
6
.github/workflows/test-models.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark tank/test_models.py -k cpu
|
||||
pytest --benchmark --ci tank/test_models.py -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark tank/test_models.py -k "cuda" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
|
||||
pytest --benchmark --ci tank/test_models.py -k "cuda" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
|
||||
@@ -110,4 +110,4 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest tank/test_models.py -k 'vulkan' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
|
||||
pytest --ci tank/test_models.py -k 'vulkan' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
|
||||
|
||||
15
conftest.py
15
conftest.py
@@ -18,22 +18,15 @@ def pytest_addoption(parser):
|
||||
default="False",
|
||||
help="Use TensorFloat-32 calculations.",
|
||||
)
|
||||
# The following options are deprecated and pending removal.
|
||||
parser.addoption(
|
||||
"--save_mlir",
|
||||
"--save_repro",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to save input MLIR",
|
||||
help="Pass option to save reproduction artifacts to SHARK/shark_tmp/test_case/",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_vmfb",
|
||||
"--ci",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Pass option to save IREE output .vmfb",
|
||||
)
|
||||
parser.addoption(
|
||||
"--save_temps",
|
||||
action="store_true",
|
||||
default="False",
|
||||
help="Saves IREE reproduction artifacts for filing upstream issues.",
|
||||
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation.",
|
||||
)
|
||||
|
||||
@@ -12,10 +12,14 @@ from shark.shark_downloader import (
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
import iree.compiler as ireec
|
||||
import pytest
|
||||
import unittest
|
||||
import numpy as np
|
||||
import csv
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
def load_csv_and_convert(filename, gen=False):
|
||||
@@ -148,31 +152,72 @@ class SharkModuleTester:
|
||||
mlir_dialect=self.config["dialect"],
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
try:
|
||||
shark_module.compile()
|
||||
except:
|
||||
if any([self.ci, self.save_repro]) == True:
|
||||
self.save_reproducers()
|
||||
if self.ci == True:
|
||||
self.upload_repro()
|
||||
raise
|
||||
|
||||
result = shark_module.forward(inputs)
|
||||
golden_out, result = self.postprocess_outputs(golden_out, result)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
golden_out,
|
||||
result,
|
||||
rtol=self.config["rtol"],
|
||||
atol=self.config["atol"],
|
||||
)
|
||||
try:
|
||||
np.testing.assert_allclose(
|
||||
golden_out,
|
||||
result,
|
||||
rtol=self.config["rtol"],
|
||||
atol=self.config["atol"],
|
||||
)
|
||||
except AssertionError:
|
||||
if any([self.ci, self.save_repro]) == True:
|
||||
self.save_reproducers()
|
||||
if self.ci == True:
|
||||
self.upload_repro()
|
||||
if self.benchmark == True:
|
||||
self.benchmark_module(shark_module, inputs, dynamic, device)
|
||||
raise
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
if shark_args.enable_tf32 == True:
|
||||
shark_module.compile()
|
||||
shark_args.enable_tf32 = False
|
||||
self.benchmark_module(shark_module, inputs, dynamic, device)
|
||||
|
||||
shark_args.onnx_bench = self.onnx_bench
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(inputs),
|
||||
self.config["model_name"],
|
||||
dynamic,
|
||||
device,
|
||||
self.config["framework"],
|
||||
)
|
||||
if self.save_repro == True:
|
||||
self.save_reproducers()
|
||||
|
||||
def benchmark_module(self, shark_module, inputs, dynamic, device):
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
if shark_args.enable_tf32 == True:
|
||||
shark_module.compile()
|
||||
shark_args.enable_tf32 = False
|
||||
|
||||
shark_args.onnx_bench = self.onnx_bench
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(inputs),
|
||||
self.config["model_name"],
|
||||
dynamic,
|
||||
device,
|
||||
self.config["framework"],
|
||||
)
|
||||
|
||||
def save_reproducers(self):
|
||||
# Saves contents of IREE TempFileSaver temporary directory to ./shark_tmp/saved/<test_case>.
|
||||
src = self.temp_dir
|
||||
trg = f"./shark_tmp/saved/{self.tmp_prefix}"
|
||||
if not os.path.isdir("./shark_tmp/saved/"):
|
||||
os.mkdir("./shark_tmp/saved/")
|
||||
if not os.path.isdir(trg):
|
||||
os.mkdir(trg)
|
||||
files = os.listdir(src)
|
||||
for fname in files:
|
||||
shutil.copy2(os.path.join(src, fname), trg)
|
||||
|
||||
def upload_repro(self):
|
||||
import subprocess
|
||||
|
||||
bashCommand = f"gsutil cp -r ./shark_tmp/saved/{self.tmp_prefix}/* gs://shark-public/builder/repro_artifacts/"
|
||||
process = subprocess.run(bashCommand.split())
|
||||
|
||||
def postprocess_outputs(self, golden_out, result):
|
||||
# Prepares result tensors of forward pass and golden values for comparison, when needed.
|
||||
@@ -202,11 +247,14 @@ class SharkModuleTest(unittest.TestCase):
|
||||
def test_module(self, dynamic, device, config):
|
||||
self.module_tester = SharkModuleTester(config)
|
||||
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
|
||||
self.module_tester.save_repro = self.pytestconfig.getoption(
|
||||
"save_repro"
|
||||
)
|
||||
self.module_tester.onnx_bench = self.pytestconfig.getoption(
|
||||
"onnx_bench"
|
||||
)
|
||||
self.module_tester.tf32 = self.pytestconfig.getoption("tf32")
|
||||
|
||||
self.module_tester.ci = self.pytestconfig.getoption("ci")
|
||||
if (
|
||||
config["model_name"] == "facebook/convnext-tiny-224"
|
||||
and device == "cuda"
|
||||
@@ -277,4 +325,18 @@ class SharkModuleTest(unittest.TestCase):
|
||||
reason="Dynamic shapes not supported for this framework."
|
||||
)
|
||||
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
safe_name = (
|
||||
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
|
||||
)
|
||||
self.module_tester.tmp_prefix = safe_name.replace("/", "_")
|
||||
|
||||
if not os.path.isdir("./shark_tmp/"):
|
||||
os.mkdir("./shark_tmp/")
|
||||
|
||||
tempdir = tempfile.TemporaryDirectory(
|
||||
prefix=self.module_tester.tmp_prefix, dir="./shark_tmp/"
|
||||
)
|
||||
self.module_tester.temp_dir = tempdir.name
|
||||
|
||||
with ireec.tools.TempFileSaver(tempdir.name):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
Reference in New Issue
Block a user