Add --local_tank_cache flag and update requirements. (#368)

* Add --local_tank_cache flag and update requirements.

* Update requirements-importer.txt
This commit is contained in:
Ean Garvey
2022-09-28 03:02:59 -05:00
committed by GitHub
parent 28daf410b6
commit 9035a2eed3
8 changed files with 32 additions and 9 deletions

View File

@@ -90,7 +90,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} tank/test_models.py -k cpu
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -100,7 +100,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark --ci --ci_sha=${SHORT_SHA} tank/test_models.py -k cuda
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
@@ -110,4 +110,4 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --ci --ci_sha=${SHORT_SHA} tank/test_models.py -k 'vulkan' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan

View File

@@ -42,6 +42,12 @@ def pytest_addoption(parser):
default="None",
help="Passes the github SHA of the CI workflow to include in google storage directory for reproduction artifacts.",
)
parser.addoption(
"--local_tank_cache",
action="store",
default="",
help="Specify the directory in which all downloaded shark_tank artifacts will be cached.",
)
parser.addoption(
"--tank_url",
type=str,

View File

@@ -19,7 +19,7 @@ tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers==4.18.0
transformers
tensorflow-probability
#jax[cpu]

View File

@@ -30,6 +30,7 @@ Pillow
lit
pyyaml
python-dateutil
sacremoses
# web dependencies.
gradio

View File

@@ -7,9 +7,6 @@ tqdm
# SHARK Downloader
gsutil
# generate_sharktank
transformers==4.18.0
# Testing
pytest
pytest-xdist

View File

@@ -87,5 +87,10 @@ parser.add_argument(
action="store_true",
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
)
parser.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -34,15 +34,25 @@ input_type_to_np_dtype = {
# Save the model in the home local so it needn't be fetched every time in the CI.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
custom_path = shark_args.local_tank_cache
if os.path.exists(alt_path):
WORKDIR = alt_path
print(
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
)
if custom_path:
if not os.path.exists(custom_path):
os.mkdir(custom_path)
WORKDIR = custom_path
print(f"Using {WORKDIR} as local shark_tank cache directory.")
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
print(WORKDIR)
print(
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache="
" pytest flag"
)
# Checks whether the directory and files exists.
def check_dir_exists(model_name, frontend="torch", dynamic=""):

View File

@@ -130,6 +130,7 @@ class SharkModuleTester:
self.config = config
def create_and_check_module(self, dynamic, device):
shark_args.local_tank_cache = self.local_tank_cache
if self.config["framework"] == "tf":
model, func_name, inputs, golden_out = download_tf_model(
self.config["model_name"],
@@ -262,6 +263,9 @@ class SharkModuleTest(unittest.TestCase):
self.module_tester.tf32 = self.pytestconfig.getoption("tf32")
self.module_tester.ci = self.pytestconfig.getoption("ci")
self.module_tester.ci_sha = self.pytestconfig.getoption("ci_sha")
self.module_tester.local_tank_cache = self.pytestconfig.getoption(
"local_tank_cache"
)
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
if (
config["model_name"] == "distilbert-base-uncased"