diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 063354b2..2a2f521f 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -14,7 +14,7 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc from apps.stable_diffusion.src.utils.stable_args import args from apps.stable_diffusion.src.utils.resources import opt_flags from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation -import sys, functools, operator +import sys from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( load_pipeline_from_original_stable_diffusion_ckpt, ) @@ -424,18 +424,26 @@ def fetch_or_delete_vmfbs( for model in extended_model_name ] vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path] - all_vmfb_present = functools.reduce(operator.__and__, vmfb_present) - compiled_models = [None] * 4 if need_vae_encode else [None] * 3 + all_vmfb_present = True + compiled_models = [] + for i in range(3): + all_vmfb_present = all_vmfb_present and vmfb_present[i] + compiled_models.append(None) + if need_vae_encode: + all_vmfb_present = all_vmfb_present and vmfb_present[3] + compiled_models.append(None) + # We need to delete vmfbs only if some of the models were compiled. if not all_vmfb_present: - for i in range(len(vmfb_path)): + for i in range(len(compiled_models)): if vmfb_present[i]: os.remove(vmfb_path[i]) print("Deleted: ", vmfb_path[i]) else: - for i in range(len(vmfb_path)): + model_name = [model for model in extended_model_name.keys()] + for i in range(len(compiled_models)): compiled_models[i] = load_vmfb( - vmfb_path[i], extended_model_name[i], precision + vmfb_path[i], model_name[i], precision ) return compiled_models