def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
    """Return the model class name recorded in a diffusers/transformers config dict.

    Diffusers-style configs record the class under "_class_name"; transformers-style
    configs record candidate classes under "architectures" (a list — the first entry
    is used). Returns None when the dict carries neither key.
    """
    if "_class_name" in config:
        return config["_class_name"]
    elif "architectures" in config:
        return config["architectures"][0]
    else:
        return None


class VAEConfigBase(CheckpointConfigBase):
    """Shared base for standalone VAE model configs."""

    type: Literal[ModelType.VAE] = ModelType.VAE


class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    # State-dict key prefixes shared by single-file VAE checkpoints.
    KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how certain we are that `mod` is a single-file VAE checkpoint."""
        is_vae_override = overrides.get("type") is ModelType.VAE
        is_checkpoint_override = overrides.get("format") is ModelFormat.Checkpoint

        # An explicit type+format override short-circuits all heuristics.
        if is_vae_override and is_checkpoint_override:
            return MatchCertainty.OVERRIDE

        # Checkpoint VAEs are single files; a directory can never match.
        if mod.path.is_dir():
            return MatchCertainty.NEVER

        if mod.has_keys_starting_with(cls.KEY_PREFIXES):
            return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Derive the config fields that cannot be defaulted (the base model type)."""
        base = cls.get_base_type(mod)
        return {"base": base}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model type from the file name.

        Heuristic: VAEs of all architectures have a similar structure; the best we
        can do is guess based on name.

        Raises:
            InvalidModelConfigException: if no name pattern matches.
        """
        # Order matters: the generic "vae" pattern must come after the more
        # specific SDXL/SD2 patterns.
        for regexp, basetype in [
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        ]:
            if re.search(regexp, mod.path.name, re.IGNORECASE):
                return basetype

        raise InvalidModelConfigException("Cannot determine base type")


class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    # Diffusers class names that identify a standalone VAE.
    CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how certain we are that `mod` is a diffusers-format VAE folder."""
        is_vae_override = overrides.get("type") is ModelType.VAE
        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers

        # An explicit type+format override short-circuits all heuristics.
        if is_vae_override and is_diffusers_override:
            return MatchCertainty.OVERRIDE

        # Diffusers models are folders; a single file can never match.
        if mod.path.is_file():
            return MatchCertainty.NEVER

        try:
            config = cls.get_config(mod)
            class_name = get_class_name_from_config(config)
            if class_name in cls.CLASS_NAMES:
                return MatchCertainty.EXACT
        except Exception:
            # A missing or unreadable config.json simply means "not a diffusers VAE".
            pass

        return MatchCertainty.NEVER

    @classmethod
    def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Load and return the model folder's config.json as a dict."""
        config_path = mod.path / "config.json"
        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        with open(config_path, "r", encoding="utf-8") as file:
            return json.load(file)

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Derive the config fields that cannot be defaulted (the base model type)."""
        base = cls.get_base_type(mod)
        return {"base": base}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model type: SDXL if the config or the name says so, else SD1."""
        if cls._config_looks_like_sdxl(mod) or cls._name_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        # We do not support diffusers VAEs for any other base model at this time... YOLO
        return BaseModelType.StableDiffusion1

    @classmethod
    def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        config = cls.get_config(mod)
        # Heuristic: these config values distinguish Stability's SD 1.x VAE from their SDXL VAE.
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    @classmethod
    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
        # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
        # we can do is guess based on name.
        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))

    @classmethod
    def _guess_name(cls, mod: ModelOnDisk) -> str:
        # A VAE nested inside a pipeline folder lives in a directory literally named
        # "vae"; in that case the parent folder carries the meaningful name.
        name = mod.path.name
        if name == "vae":
            name = mod.path.parent.name
        return name
- return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.model_path.name - if name == "vae": - name = self.model_path.parent.name - return name - - class T5EncoderFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: return BaseModelType.Any @@ -1080,7 +1034,6 @@ class T2IAdapterFolderProbe(FolderProbeBase): # Register probe classes ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe) @@ -1093,7 +1046,6 @@ ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) diff --git 
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
    """Return True if the state dict contains every key in `keys` exactly.

    Non-string keys in the state dict (e.g. metadata entries) are ignored.
    An empty `keys` set trivially returns True.
    """
    state_dict = self.load_state_dict(path)
    str_keys = {key for key in state_dict.keys() if isinstance(key, str)}
    return keys.issubset(str_keys)


def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True if any string key in the state dict starts with any of `prefixes`."""
    state_dict = self.load_state_dict(path)
    # str.startswith accepts a tuple of candidates: one C-level call per key.
    prefix_tuple = tuple(prefixes)
    return any(key.startswith(prefix_tuple) for key in state_dict.keys() if isinstance(key, str))


def has_keys_ending_with(self, suffixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True if any string key in the state dict ends with any of `suffixes`."""
    state_dict = self.load_state_dict(path)
    # str.endswith accepts a tuple of candidates: one C-level call per key.
    suffix_tuple = tuple(suffixes)
    return any(key.endswith(suffix_tuple) for key in state_dict.keys() if isinstance(key, str))