Compare commits

...

1 Commits

Author SHA1 Message Date
Ean Garvey
2bd2bfa8b9 Allow SD model wrapper to be called for sharktank gen. 2023-04-18 21:09:36 -05:00
2 changed files with 45 additions and 1 deletions

View File

@@ -666,3 +666,47 @@ class SharkifyStableDiffusionModel:
return compiled_stencil_adaptor
except Exception as e:
sys.exit(e)
def __call__(self):
from utils import get_vmfb_path_name
from stable_args import args
import traceback, functools, operator, os
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
for model_id in base_models:
self.inputs = self.get_input_info_for(
base_models[model_id],
)
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -13,7 +13,7 @@ We currently make use of the [AI-Render Plugin](https://github.com/benrugg/AI-Re
.\shark_sd_<date>_<ver>.exe --api
## For example:
.\shark_sd_20230411_671.exe --api --server_port=8082
.\shark_sd_20230411_684.exe --api --server_port=8082
## From a the base directory of a source clone of SHARK:
./setup_venv.ps1