[SHARK][SD] Add --local_tank_cache flag in the stable diffusion

This flag can be used to set local shark_tank cache directory.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-14 19:27:35 +05:30
parent 986c126a5c
commit e67ea31ee2
4 changed files with 20 additions and 0 deletions

View File

@@ -103,6 +103,12 @@ p.add_argument(
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,

View File

@@ -40,6 +40,10 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
mlir_model, func_name, inputs, golden_out = download_model(
model_name,

View File

@@ -110,6 +110,12 @@ p.add_argument(
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4294967296",

View File

@@ -38,6 +38,10 @@ def _compile_module(args, shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(args, tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
mlir_model, func_name, inputs, golden_out = download_model(
model_name, tank_url=tank_url, frontend="torch"