From 1c62ae461e410c9c933571351cbae2e719cf3c68 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 18 Jan 2023 12:15:57 -0500 Subject: [PATCH] fix vae safetensor loading --- ldm/invoke/model_manager.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 8aeeda650a..671e00aacd 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -359,10 +359,14 @@ class ModelManager(object): vae = os.path.normpath(os.path.join(Globals.root,vae)) if os.path.exists(vae): print(f' | Loading VAE weights from: {vae}') - vae_ckpt = safetensors.torch.load_file(vae) \ - if vae.endswith('.safetensors') \ - else torch.load(vae, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_ckpt = None + vae_dict = None + if vae.endswith('.safetensors'): + vae_ckpt = safetensors.torch.load_file(vae) + vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"} + else: + vae_ckpt = torch.load(vae, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict, strict=False) else: print(f' | VAE file {vae} not found. Skipping.')