mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
autoconvert legacy VAEs (#3235)
This draft PR implements a system in which if a diffusers model is loaded, and the model manager detects that the user tried to assign a legacy checkpoint VAE to the model, the checkpoint will be converted to a diffusers VAE in RAM. It is draft because it has not been carefully tested yet, and there are some edge cases that are not handled properly.
This commit is contained in:
@@ -620,7 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
|
||||
return new_checkpoint
|
||||
|
||||
def convert_ldm_vae_state_dict(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
|
||||
@@ -9,7 +9,6 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@@ -31,11 +30,10 @@ from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from ldm.invoke.devices import CPU_DEVICE
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
|
||||
from ldm.util import ask_user, download_with_resume, url_attachment_name
|
||||
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
@@ -370,8 +368,9 @@ class ModelManager(object):
|
||||
print(
|
||||
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
|
||||
)
|
||||
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
from .ckpt_to_diffuser import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
@@ -1230,6 +1229,13 @@ class ModelManager(object):
|
||||
return vae_path
|
||||
|
||||
def _load_vae(self, vae_config) -> AutoencoderKL:
|
||||
|
||||
# Handle the common case of a user shoving a VAE .ckpt into
|
||||
# the vae field for a diffusers. We convert it into diffusers
|
||||
# format and use it.
|
||||
if type(vae_config) in [str,Path]:
|
||||
return self.convert_vae(vae_config)
|
||||
|
||||
vae_args = {}
|
||||
try:
|
||||
name_or_path = self.model_name_or_path(vae_config)
|
||||
@@ -1277,6 +1283,32 @@ class ModelManager(object):
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
|
||||
print(f" | A checkpoint VAE was detected. Converting to diffusers format.")
|
||||
vae_path = Path(Globals.root,vae_path).resolve()
|
||||
|
||||
from .ckpt_to_diffuser import (
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_state_dict,
|
||||
)
|
||||
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.suffix in ['.pt','.ckpt']:
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
else:
|
||||
vae_state_dict = safetensors.torch.load_file(vae_path)
|
||||
if 'state_dict' in vae_state_dict:
|
||||
vae_state_dict = vae_state_dict['state_dict']
|
||||
# TODO: see if this works with 1.x inpaint models and 2.x models
|
||||
config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
|
||||
original_conf = OmegaConf.load(config_file_path)
|
||||
vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
|
||||
diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(diffusers_vae)
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("diffusers"))
|
||||
|
||||
Reference in New Issue
Block a user