mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port vae to new API
This commit is contained in:
@@ -23,6 +23,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
# pyright: reportIncompatibleVariableOverride=false
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -73,6 +74,15 @@ class InvalidModelConfigException(Exception):
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
    """Pull the model class name out of a loaded model config dict.

    Diffusers-style configs record the class under "_class_name"; transformers-style
    configs record a list under "architectures" (the first entry wins). Returns None
    when neither key is present.
    """
    if "_class_name" in config:
        return config["_class_name"]
    if "architectures" in config:
        return config["architectures"][0]
    return None
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
@@ -578,18 +588,122 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
}
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class VAEConfigBase(CheckpointConfigBase):
    """Shared base for VAE model configs; pins the `type` discriminator field to VAE."""

    type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
|
||||
|
||||
class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    # State-dict key prefixes that identify a standalone VAE checkpoint.
    KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Report how certainly the model on disk is a VAE in checkpoint format."""
        # Explicit user overrides for both type and format trump any probing.
        if overrides.get("type") is ModelType.VAE and overrides.get("format") is ModelFormat.Checkpoint:
            return MatchCertainty.OVERRIDE

        # A checkpoint is a single file; a directory cannot match this format.
        if mod.path.is_dir():
            return MatchCertainty.NEVER

        if not mod.has_keys_starting_with(cls.KEY_PREFIXES):
            return MatchCertainty.NEVER

        return MatchCertainty.MAYBE

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return config fields that must be derived by inspecting the model on disk."""
        return {"base": cls.get_base_type(mod)}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model type from the file name.

        Heuristic: VAEs of all architectures have a similar structure; the best we can
        do is guess based on name. The first matching pattern wins, so order matters.

        Raises:
            InvalidModelConfigException: if no pattern matches the file name.
        """
        name_patterns = (
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        )
        for pattern, base in name_patterns:
            if re.search(pattern, mod.path.name, re.IGNORECASE):
                return base

        raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
|
||||
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    # Diffusers autoencoder classes we recognize in config.json.
    CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Report how certainly the model on disk is a diffusers-format VAE."""
        # Explicit user overrides for both type and format trump any probing.
        if overrides.get("type") is ModelType.VAE and overrides.get("format") is ModelFormat.Diffusers:
            return MatchCertainty.OVERRIDE

        # Diffusers models are directories; a lone file cannot match.
        if mod.path.is_file():
            return MatchCertainty.NEVER

        # Best effort: a missing or unparseable config.json simply means no match.
        try:
            class_name = get_class_name_from_config(cls.get_config(mod))
        except Exception:
            return MatchCertainty.NEVER

        if class_name in cls.CLASS_NAMES:
            return MatchCertainty.EXACT
        return MatchCertainty.NEVER

    @classmethod
    def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Load the model's config.json and return it as a dict."""
        with open(mod.path / "config.json", "r") as file:
            return json.load(file)

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return config fields that must be derived by inspecting the model on disk."""
        return {"base": cls.get_base_type(mod)}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model type, preferring config evidence over the name."""
        if cls._config_looks_like_sdxl(mod) or cls._name_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        # We do not support diffusers VAEs for any other base model at this time... YOLO
        return BaseModelType.StableDiffusion1

    @classmethod
    def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        # Heuristic: these config values distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config = cls.get_config(mod)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    @classmethod
    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float
        # scaled down by a factor of 8), so we can't necessarily tell them apart by config
        # hyperparameters. Best we can do is guess based on name.
        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))

    @classmethod
    def _guess_name(cls, mod: ModelOnDisk) -> str:
        # Repos commonly nest the VAE in a "vae/" subfolder; use the parent's name then.
        name = mod.path.name
        return mod.path.parent.name if name == "vae" else name
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
@@ -654,21 +653,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return SchedulerPredictionType.Epsilon
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe for standalone VAE checkpoints."""

    def get_base_type(self) -> BaseModelType:
        # VAEs of all base types have the same structure, so we wimp out and
        # guess using the name. First matching pattern wins, so order matters.
        name_hints = (
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        )
        for pattern, base in name_hints:
            if re.search(pattern, self.model_path.name, re.IGNORECASE):
                return base
        raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for LoRA checkpoints."""
|
||||
|
||||
@@ -895,36 +879,6 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
    """Probe for diffusers-format VAE folders."""

    def get_base_type(self) -> BaseModelType:
        # SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled
        # down by a factor of 8), so we can't necessarily tell them apart by config
        # hyperparameters; fall back to a name-based guess when the config is ambiguous.
        if self._config_looks_like_sdxl() or self._name_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        # Repos commonly nest the VAE in a "vae/" subfolder; use the parent's name then.
        name = self.model_path.name
        return self.model_path.parent.name if name == "vae" else name
|
||||
|
||||
|
||||
class T5EncoderFolderProbe(FolderProbeBase):
    """Probe for T5 text-encoder folders."""

    def get_base_type(self) -> BaseModelType:
        # T5 encoders are not tied to a single generation architecture.
        return BaseModelType.Any
|
||||
@@ -1080,7 +1034,6 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
# Register probe classes
# Folder (diffusers-layout) probes, keyed by model type.
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
|
||||
@@ -1093,7 +1046,6 @@ ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe
|
||||
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)

# Single-file (checkpoint) probes, keyed by model type.
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
|
||||
@@ -128,3 +128,19 @@ class ModelOnDisk:
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
return path
|
||||
|
||||
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
    """Return True if every key in `keys` appears verbatim in the state dict."""
    state_dict = self.load_state_dict(path)
    string_keys = {k for k in state_dict.keys() if isinstance(k, str)}
    return keys.issubset(string_keys)
|
||||
|
||||
def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True if any string key in the state dict starts with any of `prefixes`."""
    state_dict = self.load_state_dict(path)
    # str.startswith accepts a tuple, letting one call test all prefixes at once.
    prefix_tuple = tuple(prefixes)
    return any(k.startswith(prefix_tuple) for k in state_dict.keys() if isinstance(k, str))
|
||||
|
||||
def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True if any string key in the state dict ends with any of the given suffixes.

    NOTE(review): the parameter is named `prefixes` but actually holds suffixes; the
    name is kept so keyword callers don't break — consider renaming upstream.
    """
    state_dict = self.load_state_dict(path)
    # str.endswith accepts a tuple, letting one call test all suffixes at once.
    suffix_tuple = tuple(prefixes)
    return any(k.endswith(suffix_tuple) for k in state_dict.keys() if isinstance(k, str))
|
||||
|
||||
Reference in New Issue
Block a user