autoconvert legacy VAEs (#3235)

This draft PR implements a system in which if a diffusers model is
loaded, and the model manager detects that the user tried to assign a
legacy checkpoint VAE to the model, the checkpoint will be converted to
a diffusers VAE in RAM.

It is draft because it has not been carefully tested yet, and there are
some edge cases that are not handled properly.
This commit is contained in:
blessedcoolant
2023-05-19 12:26:25 +12:00
committed by GitHub
2 changed files with 40 additions and 5 deletions

View File

@@ -620,7 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
return new_checkpoint
def convert_ldm_vae_state_dict(vae_state_dict, config):
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]

View File

@@ -9,7 +9,6 @@ from __future__ import annotations
import contextlib
import gc
import hashlib
import io
import os
import re
import sys
@@ -31,11 +30,10 @@ from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import Globals, global_cache_dir
from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
from ldm.util import ask_user, download_with_resume, url_attachment_name
class SDLegacyType(Enum):
@@ -370,8 +368,9 @@ class ModelManager(object):
print(
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
)
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
from .ckpt_to_diffuser import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -1230,6 +1229,13 @@ class ModelManager(object):
return vae_path
def _load_vae(self, vae_config) -> AutoencoderKL:
# Handle the common case of a user shoving a VAE .ckpt into
# the vae field for a diffusers. We convert it into diffusers
# format and use it.
if type(vae_config) in [str,Path]:
return self.convert_vae(vae_config)
vae_args = {}
try:
name_or_path = self.model_name_or_path(vae_config)
@@ -1277,6 +1283,32 @@ class ModelManager(object):
return vae
@staticmethod
def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
print(f" | A checkpoint VAE was detected. Converting to diffusers format.")
vae_path = Path(Globals.root,vae_path).resolve()
from .ckpt_to_diffuser import (
create_vae_diffusers_config,
convert_ldm_vae_state_dict,
)
vae_path = Path(vae_path)
if vae_path.suffix in ['.pt','.ckpt']:
vae_state_dict = torch.load(vae_path, map_location="cpu")
else:
vae_state_dict = safetensors.torch.load_file(vae_path)
if 'state_dict' in vae_state_dict:
vae_state_dict = vae_state_dict['state_dict']
# TODO: see if this works with 1.x inpaint models and 2.x models
config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
original_conf = OmegaConf.load(config_file_path)
vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(diffusers_vae)
return vae
@staticmethod
def _delete_model_from_cache(repo_id):
cache_info = scan_cache_dir(global_cache_dir("diffusers"))