feat(mm): port vae to new API

This commit is contained in:
psychedelicious
2025-09-23 15:17:44 +10:00
parent 37de184198
commit 3dfcf9a869
3 changed files with 134 additions and 52 deletions

View File

@@ -23,6 +23,7 @@ Validation errors will raise an InvalidModelConfigException error.
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import re
import time
from abc import ABC, abstractmethod
from enum import Enum
@@ -73,6 +74,15 @@ class InvalidModelConfigException(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
    """Extract the model class name from a loaded model config dict.

    Checks the diffusers-style `_class_name` key first, then falls back to the
    first entry of the transformers-style `architectures` list.

    Returns:
        The class name string, or None if neither key yields a usable name.
    """
    if "_class_name" in config:
        return config["_class_name"]
    # `architectures` is a list; guard against it being empty so a malformed
    # config returns None instead of raising IndexError.
    architectures = config.get("architectures")
    if architectures:
        return architectures[0]
    return None
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -578,18 +588,122 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
}
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
class VAEConfigBase(CheckpointConfigBase):
    """Shared base for VAE model configs; pins the `type` discriminator to VAE."""

    # Discriminator field shared by all VAE config variants (checkpoint/diffusers).
    type: Literal[ModelType.VAE] = ModelType.VAE
class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    # State-dict key prefixes whose presence suggests a standalone VAE checkpoint.
    KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how certain we are that `mod` is a single-file VAE checkpoint.

        Explicit type+format overrides win outright; otherwise require a single
        file whose state dict contains VAE-like encoder/decoder keys.
        """
        # `is` works here because these compare against enum members (singletons).
        is_vae_override = overrides.get("type") is ModelType.VAE
        is_checkpoint_override = overrides.get("format") is ModelFormat.Checkpoint
        if is_vae_override and is_checkpoint_override:
            return MatchCertainty.OVERRIDE
        # A checkpoint is a single file; a directory can never match this format.
        if mod.path.is_dir():
            return MatchCertainty.NEVER
        if mod.has_keys_starting_with(cls.KEY_PREFIXES):
            # These prefixes are suggestive but not unique to standalone VAEs,
            # hence MAYBE rather than EXACT.
            return MatchCertainty.MAYBE
        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the config fields that must be probed from disk (just `base`)."""
        base = cls.get_base_type(mod)
        return {"base": base}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model this VAE belongs to from its file name.

        Raises InvalidModelConfigException when no pattern matches.
        """
        # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
        # Order matters: the first matching pattern wins, so the generic "vae"
        # pattern must come after the more specific "xl"/"sd2" ones.
        # NOTE(review): the unescaped "." in the FLUX pattern matches any
        # character — harmless in practice, but r"FLUX\.1-schnell_ae" would be stricter.
        for regexp, basetype in [
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        ]:
            if re.search(regexp, mod.path.name, re.IGNORECASE):
                return basetype
        raise InvalidModelConfigException("Cannot determine base type")
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    # `_class_name` values in config.json that identify a diffusers VAE.
    CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how certain we are that `mod` is a diffusers-format VAE folder."""
        is_vae_override = overrides.get("type") is ModelType.VAE
        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
        if is_vae_override and is_diffusers_override:
            return MatchCertainty.OVERRIDE
        # Diffusers models are directories; a single file can never match.
        if mod.path.is_file():
            return MatchCertainty.NEVER
        try:
            config = cls.get_config(mod)
            class_name = get_class_name_from_config(config)
            if class_name in cls.CLASS_NAMES:
                return MatchCertainty.EXACT
        except Exception:
            # A missing or unreadable config.json just means "not this format".
            pass
        return MatchCertainty.NEVER

    @classmethod
    def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Load and return the model folder's config.json as a dict.

        May raise OSError / json.JSONDecodeError; callers such as `matches`
        are expected to handle that.
        """
        config_path = mod.path / "config.json"
        with open(config_path, "r") as file:
            return json.load(file)

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the config fields that must be probed from disk (just `base`)."""
        base = cls.get_base_type(mod)
        return {"base": base}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Best-effort guess of the base model this VAE belongs to."""
        if cls._config_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        elif cls._name_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        else:
            # We do not support diffusers VAEs for any other base model at this time... YOLO
            return BaseModelType.StableDiffusion1

    @classmethod
    def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        config = cls.get_config(mod)
        # Heuristic: these config values distinguish Stability's SD 1.x VAE from their SDXL VAE.
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    @classmethod
    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
        # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
        # we can do is guess based on name.
        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))

    @classmethod
    def _guess_name(cls, mod: ModelOnDisk) -> str:
        # A folder literally named "vae" is typically a pipeline subfolder; the
        # parent directory name is more descriptive, so use that instead.
        name = mod.path.name
        if name == "vae":
            name = mod.path.parent.name
        return name
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""

View File

@@ -1,5 +1,4 @@
import json
import re
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Union
@@ -654,21 +653,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return SchedulerPredictionType.Epsilon
class VaeCheckpointProbe(CheckpointProbeBase):
    """Legacy probe for single-file VAE checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """Guess the base model from the checkpoint's file name.

        Raises InvalidModelConfigException when no pattern matches.
        """
        # VAEs of all base types have the same structure, so we wimp out and
        # guess using the name.
        # Order matters: the generic "vae" pattern is checked after the more
        # specific "xl"/"sd2" patterns; first match wins.
        for regexp, basetype in [
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        ]:
            if re.search(regexp, self.model_path.name, re.IGNORECASE):
                return basetype
        raise InvalidModelConfigException("Cannot determine base type")
class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
@@ -895,36 +879,6 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
    """Legacy probe for diffusers-format VAE folders."""

    def get_base_type(self) -> BaseModelType:
        """Guess the base model from config.json hyperparameters, then the name."""
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
            # by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
            return BaseModelType.StableDiffusionXL
        else:
            # Anything else is treated as SD 1.x — the only other base supported here.
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        # Heuristic: fall back to an "xl" token in the (guessed) model name.
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        # A folder literally named "vae" is typically a pipeline subfolder; the
        # parent directory name is more descriptive, so use that instead.
        name = self.model_path.name
        if name == "vae":
            name = self.model_path.parent.name
        return name
class T5EncoderFolderProbe(FolderProbeBase):
    """Legacy probe for T5 text-encoder folders."""

    def get_base_type(self) -> BaseModelType:
        # Reported as Any — presumably because T5 encoders are not tied to a
        # single base model family. NOTE(review): confirm against callers.
        return BaseModelType.Any
@@ -1080,7 +1034,6 @@ class T2IAdapterFolderProbe(FolderProbeBase):
# Register probe classes
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
@@ -1093,7 +1046,6 @@ ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)

View File

@@ -128,3 +128,19 @@ class ModelOnDisk:
f"Please specify the intended file using the 'path' argument"
)
return path
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
    """Return True if every key in `keys` appears verbatim in the state dict.

    Non-string state-dict keys are ignored. An empty `keys` set returns True.
    """
    state_dict = self.load_state_dict(path)
    string_keys = {k for k in state_dict.keys() if isinstance(k, str)}
    return all(k in string_keys for k in keys)
def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True if any string key in the state dict starts with one of `prefixes`.

    Non-string state-dict keys are ignored.
    """
    state_dict = self.load_state_dict(path)
    # str.startswith accepts a tuple of candidates, checking them in one call.
    prefix_tuple = tuple(prefixes)
    for key in state_dict.keys():
        if isinstance(key, str) and key.startswith(prefix_tuple):
            return True
    return False
def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
state_dict = self.load_state_dict(path)
return any(
any(key.endswith(suffix) for suffix in prefixes) for key in state_dict.keys() if isinstance(key, str)
)