[SD] Fix vmfb locating bug

-- This commit fixes a bug in vmfb caching when vae_encoder is involved and
   also includes a minor NFC change in the code.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Author: Abhishek Varma
Date: 2023-02-10 11:53:54 +00:00
Committed by: Abhishek Varma
Parent: 99aa77d036
Commit: 591bbcd058


@@ -14,7 +14,7 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc
 from apps.stable_diffusion.src.utils.stable_args import args
 from apps.stable_diffusion.src.utils.resources import opt_flags
 from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
-import sys, functools, operator
+import sys
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     load_pipeline_from_original_stable_diffusion_ckpt,
 )
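
Aside: the functools and operator imports existed only to fold the per-model
presence flags with reduce; the second hunk below replaces that fold with an
explicit loop over just the models the current run needs, so both imports can
go. A minimal, hypothetical sketch of the difference (the flag values are made
up, not taken from the repository):

import functools, operator

# Hypothetical flags: the first three model vmfbs exist on disk,
# the vae_encoder vmfb does not.
vmfb_present = [True, True, True, False]
need_vae_encode = False

# Old check: folds every flag, so a missing vae_encoder vmfb marks the
# whole cache as invalid even when vae_encoder is not needed.
old_all_present = functools.reduce(operator.__and__, vmfb_present)

# New check, equivalent to the patched loop: only the needed models count.
num_needed = 4 if need_vae_encode else 3
new_all_present = all(vmfb_present[:num_needed])

print(old_all_present, new_all_present)  # False True
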
@@ -424,18 +424,26 @@ def fetch_or_delete_vmfbs(
         for model in extended_model_name
     ]
     vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
-    all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
-    compiled_models = [None] * 4 if need_vae_encode else [None] * 3
+    all_vmfb_present = True
+    compiled_models = []
+    for i in range(3):
+        all_vmfb_present = all_vmfb_present and vmfb_present[i]
+        compiled_models.append(None)
+    if need_vae_encode:
+        all_vmfb_present = all_vmfb_present and vmfb_present[3]
+        compiled_models.append(None)
     # We need to delete vmfbs only if some of the models were compiled.
     if not all_vmfb_present:
-        for i in range(len(vmfb_path)):
+        for i in range(len(compiled_models)):
             if vmfb_present[i]:
                 os.remove(vmfb_path[i])
                 print("Deleted: ", vmfb_path[i])
     else:
-        for i in range(len(vmfb_path)):
+        model_name = [model for model in extended_model_name.keys()]
+        for i in range(len(compiled_models)):
             compiled_models[i] = load_vmfb(
-                vmfb_path[i], extended_model_name[i], precision
+                vmfb_path[i], model_name[i], precision
             )
     return compiled_models
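
For reference, a self-contained sketch of the fixed control flow follows. This
is not the repository code: load_vmfb is passed in as a parameter so the
snippet runs standalone, and the per-model loops from the patch are collapsed
into a slice over the first three or four entries, which behaves the same.

import os

def fetch_or_delete_vmfbs_sketch(vmfb_path, extended_model_name,
                                 need_vae_encode, precision, load_vmfb):
    # Presence flag for every candidate vmfb on disk.
    vmfb_present = [os.path.isfile(p) for p in vmfb_path]

    # Only the models needed for this run (3, or 4 when the VAE encoder
    # is used) decide whether the cache is usable.
    num_needed = 4 if need_vae_encode else 3
    compiled_models = [None] * num_needed

    if not all(vmfb_present[:num_needed]):
        # Partial cache: delete whatever exists among the needed vmfbs so
        # the whole set gets recompiled consistently.
        for i in range(num_needed):
            if vmfb_present[i]:
                os.remove(vmfb_path[i])
                print("Deleted: ", vmfb_path[i])
    else:
        # extended_model_name is a mapping, so index through a materialised
        # key list instead of indexing the mapping by integer position.
        model_name = list(extended_model_name.keys())
        for i in range(num_needed):
            compiled_models[i] = load_vmfb(vmfb_path[i], model_name[i], precision)

    return compiled_models

With a four-entry vmfb_path and need_vae_encode set to False, a missing
vae_encoder vmfb no longer invalidates (and deletes) the other three cached
artifacts, which is the caching bug the commit message refers to.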