Files
InvokeAI/invokeai/backend/model_manager/load/model_loaders/vae.py
Lincoln Stein a23dedd2ee make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
2024-03-01 10:42:33 +11:00

69 lines
2.7 KiB
Python

# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
)
vae_model.to(self._torch_dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path