diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml index e4dbc217..5929ee70 100644 --- a/.github/workflows/test-models.yml +++ b/.github/workflows/test-models.yml @@ -136,7 +136,7 @@ jobs: export DYLD_LIBRARY_PATH=/usr/local/lib/ echo $PATH pip list | grep -E "torch|iree" - pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan + pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank - name: Validate Vulkan Models (a100) if: matrix.suite == 'vulkan' && matrix.os == 'a100' diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 42f2f5b3..8d76a60b 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -57,11 +57,13 @@ def _compile_module(shark_module, model_name, extra_args=[]): # Downloads the model from shark_tank and returns the shark_module. def get_shark_model(tank_url, model_name, extra_args=[]): - from shark.shark_downloader import download_model from shark.parser import shark_args # Set local shark_tank cache directory. shark_args.local_tank_cache = args.local_tank_cache + + from shark.shark_downloader import download_model + if "cuda" in args.device: shark_args.enable_tf32 = True diff --git a/generate_sharktank.py b/generate_sharktank.py index b44aaa24..fdea2d3d 100644 --- a/generate_sharktank.py +++ b/generate_sharktank.py @@ -2,11 +2,10 @@ """SHARK Tank""" # python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url] # will generate local shark tank folder like this: -# HOME -# /.local -# /shark_tank -# /albert_lite_base -# /...model_name... +# /SHARK +# /gen_shark_tank +# /albert_lite_base +# /...model_name... # import os diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index f3786805..c1daf265 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -89,7 +89,7 @@ if custom_path is not None: print(f"Using {WORKDIR} as local shark_tank cache directory.") -if os.path.exists(alt_path): +elif os.path.exists(alt_path): WORKDIR = alt_path print( f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."