mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD-CLI] Clean up vmfbs if a retry method fails
-- This commit cleans up vmfb files generated as a result of retry method. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
committed by
Abhishek Varma
parent
4be75d4418
commit
1118b4b651
@@ -203,9 +203,15 @@ class SharkifyStableDiffusionModel:
|
||||
return shark_clip
|
||||
|
||||
def __call__(self):
|
||||
from utils import get_vmfb_path_name
|
||||
from stable_args import args
|
||||
import traceback
|
||||
import traceback, functools, operator, os
|
||||
|
||||
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
|
||||
vmfb_path = [
|
||||
get_vmfb_path_name(model + self.model_name)[0]
|
||||
for model in model_name
|
||||
]
|
||||
for model_id in base_models:
|
||||
self.inputs = get_input_info(
|
||||
base_models[model_id],
|
||||
@@ -215,12 +221,22 @@ class SharkifyStableDiffusionModel:
|
||||
self.batch_size,
|
||||
)
|
||||
try:
|
||||
compiled_clip = self.get_clip()
|
||||
compiled_unet = self.get_unet()
|
||||
compiled_vae = self.get_vae()
|
||||
compiled_clip = self.get_clip()
|
||||
except Exception as e:
|
||||
if args.enable_stack_trace:
|
||||
traceback.print_exc()
|
||||
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
|
||||
all_vmfb_present = functools.reduce(
|
||||
operator.__and__, vmfb_present
|
||||
)
|
||||
# We need to delete vmfbs only if some of the models were compiled.
|
||||
if not all_vmfb_present:
|
||||
for i in range(len(vmfb_path)):
|
||||
if vmfb_present[i]:
|
||||
os.remove(vmfb_path[i])
|
||||
print("Deleted: ", vmfb_path[i])
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
|
||||
|
||||
@@ -14,15 +14,20 @@ from sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
return [vmfb_path, extended_name]
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
|
||||
Reference in New Issue
Block a user