autoconvert ckpt VAEs assigned to diffusers models

This commit is contained in:
Lincoln Stein
2023-04-19 17:44:27 -04:00
parent ce22a1577c
commit 23d9361528
2 changed files with 19 additions and 11 deletions

View File

@@ -620,10 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = _convert_ldm_vae_checkpoint(vae_state_dict,config)
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
return new_checkpoint
def _convert_ldm_vae_checkpoint(vae_state_dict, config):
def convert_ldm_vae_state_dict(vae_state_dict, config):
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]

View File

@@ -30,12 +30,6 @@ from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from .ckpt_to_diffuser import (
load_pipeline_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import Globals, global_cache_dir
@@ -374,7 +368,10 @@ class ModelManager(object):
print(
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
)
from .ckpt_to_diffuser import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -1287,17 +1284,28 @@ class ModelManager(object):
return vae
@staticmethod
def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
print(f" | A checkpoint VAE was detected. Converting to diffusers format.")
vae_path = Path(Globals.root,vae_path).resolve()
from .ckpt_to_diffuser import (
create_vae_diffusers_config,
convert_ldm_vae_state_dict,
)
vae_path = Path(vae_path)
if vae_path.suffix in ['.pt','.ckpt']:
vae_state_dict = torch.load(vae_path)
vae_state_dict = torch.load(vae_path, map_location="cpu")
else:
vae_state_dict = safetensors.torch.load_file(vae_path)
if 'state_dict' in vae_state_dict:
vae_state_dict = vae_state_dict['state_dict']
# TODO: see if this works with 1.x inpaint models and 2.x models
config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
original_conf = OmegaConf.load(config_file_path)
vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
diffusers_vae = convert_ldm_vae_checkpoint(vae_state_dict,vae_config)
diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(diffusers_vae)
return vae