From d93d4242f938f569061ce440dcbea46f7c7baa3a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:24:29 +1000 Subject: [PATCH] refactor(mm): split configs into separate files --- .../backend/model_manager/configs/__init__.py | 0 .../backend/model_manager/configs/base.py | 243 ++++++ .../model_manager/configs/clip_embed.py | 91 +++ .../model_manager/configs/clip_vision.py | 44 ++ .../model_manager/configs/controlnet.py | 195 +++++ .../model_manager/configs/external_api.py | 0 .../backend/model_manager/configs/factory.py | 340 +++++++++ .../model_manager/configs/flux_redux.py | 40 + .../configs/identification_utils.py | 182 +++++ .../model_manager/configs/ip_adapter.py | 180 +++++ .../model_manager/configs/llava_onevision.py | 42 ++ .../backend/model_manager/configs/lora.py | 323 ++++++++ .../backend/model_manager/configs/main.py | 692 ++++++++++++++++++ .../backend/model_manager/configs/siglip.py | 44 ++ .../backend/model_manager/configs/spandrel.py | 54 ++ .../model_manager/configs/t2i_adapter.py | 79 ++ .../model_manager/configs/t5_encoder.py | 87 +++ .../configs/textual_inversion.py | 156 ++++ .../backend/model_manager/configs/unknown.py | 24 + invokeai/backend/model_manager/configs/vae.py | 166 +++++ 20 files changed, 2982 insertions(+) create mode 100644 invokeai/backend/model_manager/configs/__init__.py create mode 100644 invokeai/backend/model_manager/configs/base.py create mode 100644 invokeai/backend/model_manager/configs/clip_embed.py create mode 100644 invokeai/backend/model_manager/configs/clip_vision.py create mode 100644 invokeai/backend/model_manager/configs/controlnet.py create mode 100644 invokeai/backend/model_manager/configs/external_api.py create mode 100644 invokeai/backend/model_manager/configs/factory.py create mode 100644 invokeai/backend/model_manager/configs/flux_redux.py create mode 100644 invokeai/backend/model_manager/configs/identification_utils.py create mode 100644 
# Reconstructed module: invokeai/backend/model_manager/configs/base.py
from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from typing import (
    Any,
    ClassVar,
    Literal,
    Self,
    Type,
)

from pydantic import BaseModel, ConfigDict, Field, Tag
from pydantic_core import PydanticUndefined

from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
    AnyVariant,
    BaseModelType,
    ModelFormat,
    ModelRepoVariant,
    ModelSourceType,
    ModelType,
)


class Config_Base(ABC, BaseModel):
    """
    Abstract base class for model configurations. A model config describes a specific combination of model base, type
    and format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint
    format would have base=sd-1, type=main, format=checkpoint.

    To create a new config type, inherit from this class and implement its interface:
    - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatchError. This method
      will be called during model installation to determine the correct config class for a model.
    - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A
      default must be provided for each of these fields.

    If multiple combinations of base, type and format need to be supported, create a separate subclass for each.

    See MinimalConfigExample in test_model_probe.py for an example implementation.
    """

    # These fields are common to all model configs.

    key: str = Field(
        default_factory=uuid_string,
        description="A unique key for this model.",
    )
    hash: str = Field(
        description="The hash of the model file(s).",
    )
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
    )
    file_size: int = Field(
        description="The size of the model in bytes.",
    )
    name: str = Field(
        description="Name of the model.",
    )
    description: str | None = Field(
        default=None,
        description="Model description",
    )
    source: str = Field(
        description="The original source of the model (path, URL or repo_id).",
    )
    source_type: ModelSourceType = Field(
        description="The type of source",
    )
    source_api_response: str | None = Field(
        default=None,
        description="The original API response from the source, as stringified JSON.",
    )
    cover_image: str | None = Field(
        default=None,
        description="Url for image to preview model",
    )
    usage_info: str | None = Field(
        default=None,
        description="Usage information for this model",
    )

    CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set()
    """Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set
    of all known model config types."""

    model_config = ConfigDict(
        validate_assignment=True,
        json_schema_serialization_defaults_required=True,
        json_schema_mode_override="serialization",
    )

    @classmethod
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register non-abstract subclasses so we can iterate over them later during model probing. Note that
        # isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC.
        # We must check for ABC lest we unintentionally register some abstract model config classes.
        if not isabstract(cls) and ABC not in cls.__bases__:
            cls.CONFIG_CLASSES.add(cls)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each
        # subclass is expected to represent a single combination of base, type and format.
        #
        # This pydantic dunder method is called after the pydantic model for a class is created. The normal
        # __init_subclass__ is too early to do this check.
        for name in ("type", "base", "format"):
            if name not in cls.model_fields:
                raise NotImplementedError(f"{cls.__name__} must define a '{name}' field")
            if cls.model_fields[name].default is PydanticUndefined:
                raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field")

    @classmethod
    def get_tag(cls) -> Tag:
        """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
        pydantic uses the tag to determine which subclass to instantiate.

        The tag is a dot-separated string of the type, format, base and variant (if applicable).
        """
        tag_strings: list[str] = []
        for name in ("type", "format", "base", "variant"):
            if field := cls.model_fields.get(name):
                # The check in __pydantic_init_subclass__ ensures that type, format and base are always present with
                # defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can
                # check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by
                # pydantic to indicate that no default was provided.
                if field.default is not PydanticUndefined:
                    # We expect each of these fields has an Enum for its default; we want the value of the enum.
                    tag_strings.append(field.default.value)
        return Tag(".".join(tag_strings))

    @staticmethod
    def get_model_discriminator_value(v: Any) -> str:
        """Computes the discriminator value for a model config discriminated union."""
        # This is called by pydantic during deserialization and serialization to determine which model the data
        # represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass
        # (during serialization).
        #
        # See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
        if isinstance(v, Config_Base):
            # We have an instance of a Config_Base subclass - use its tag directly.
            return v.get_tag().tag
        if isinstance(v, dict):
            # We have a dict - attempt to compute a tag from its fields.
            tag_strings: list[str] = []
            if type_ := v.get("type"):
                if isinstance(type_, Enum):
                    type_ = str(type_.value)
                elif not isinstance(type_, str):
                    raise TypeError("Model config dict 'type' field must be a string or Enum")
                tag_strings.append(type_)

            if format_ := v.get("format"):
                if isinstance(format_, Enum):
                    format_ = str(format_.value)
                elif not isinstance(format_, str):
                    raise TypeError("Model config dict 'format' field must be a string or Enum")
                tag_strings.append(format_)

            if base_ := v.get("base"):
                if isinstance(base_, Enum):
                    base_ = str(base_.value)
                elif not isinstance(base_, str):
                    raise TypeError("Model config dict 'base' field must be a string or Enum")
                tag_strings.append(base_)

            # Special case: CLIP Embed models also need the variant to distinguish them.
            if (
                type_ == ModelType.CLIPEmbed.value
                and format_ == ModelFormat.Diffusers.value
                and base_ == BaseModelType.Any.value
            ):
                if variant_ := v.get("variant"):
                    if isinstance(variant_, Enum):
                        variant_ = variant_.value
                    elif not isinstance(variant_, str):
                        raise TypeError("Model config dict 'variant' field must be a string or Enum")
                    tag_strings.append(variant_)
                else:
                    raise ValueError("CLIP Embed model config dict must include a 'variant' field")

            return ".".join(tag_strings)
        else:
            raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")

    # FIX: @abstractmethod must be the innermost decorator (i.e. listed below @classmethod) when combined with
    # @classmethod, per the Python `abc` documentation. The original order (@abstractmethod above... wait, it was
    # @abstractmethod outermost over @classmethod innermost) does not reliably mark the method abstract.
    @classmethod
    @abstractmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Given the model on disk and any override fields, attempt to construct an instance of this config class.

        This method serves to identify whether the model on disk matches this config class, and if so, to extract any
        additional metadata needed to instantiate the config.

        Implementations should raise a NotAMatchError if the model does not match this config class."""
        raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")


class Checkpoint_Config_Base(ABC, BaseModel):
    """Base class for checkpoint-style models."""

    config_path: str | None = Field(
        description="Path to the config for this model, if any.",
        default=None,
    )


class Diffusers_Config_Base(ABC, BaseModel):
    """Base class for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
    repo_variant: ModelRepoVariant = Field(ModelRepoVariant.Default)

    @classmethod
    def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
        # Get all files ending in .bin or .safetensors and inspect their names to infer the repo variant.
        weight_files = list(mod.path.glob("**/*.safetensors"))
        weight_files.extend(list(mod.path.glob("**/*.bin")))
        for x in weight_files:
            if ".fp16" in x.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in x.name:
                return ModelRepoVariant.OpenVINO
            if "flax_model" in x.name:
                return ModelRepoVariant.Flax
            # NOTE(review): this branch is unreachable - the globs above only collect .safetensors/.bin files, so
            # x.suffix can never be ".onnx". Left as-is to preserve behavior; confirm whether ONNX detection should
            # glob "**/*.onnx" as well.
            if x.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default


class SubmodelDefinition(BaseModel):
    # path_or_prefix: filesystem path or key prefix locating the submodel within the parent model.
    path_or_prefix: str
    model_type: ModelType
    variant: AnyVariant | None = None

    model_config = ConfigDict(protected_namespaces=())
Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ClipVariantType, + ModelFormat, + ModelType, +) + + +def get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: + try: + hidden_size = config.get("hidden_size") + match hidden_size: + case 1280: + return ClipVariantType.G + case 768: + return ClipVariantType.L + case _: + return None + except Exception: + return None + + +class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base): + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, + ) + + cls._validate_variant(mod) + + return cls(**override_fields) + + @classmethod + def _validate_variant(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model variant does not match this config class.""" + expected_variant = cls.model_fields["variant"].default + config = get_config_dict_or_raise( + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + ) + recognized_variant = get_clip_variant_type_from_config(config) + + if recognized_variant is None: + raise NotAMatchError("unable to determine CLIP variant from config") + + if expected_variant is not recognized_variant: + 
raise NotAMatchError(f"variant is {recognized_variant}, not {expected_variant}") + + +class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) + + +class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) diff --git a/invokeai/backend/model_manager/configs/clip_vision.py b/invokeai/backend/model_manager/configs/clip_vision.py new file mode 100644 index 0000000000..eb17fb9e33 --- /dev/null +++ b/invokeai/backend/model_manager/configs/clip_vision.py @@ -0,0 +1,44 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + common_config_paths, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for CLIPVision.""" + + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "CLIPVisionModelWithProjection", + }, + ) + + return cls(**override_fields) diff --git a/invokeai/backend/model_manager/configs/controlnet.py 
b/invokeai/backend/model_manager/configs/controlnet.py new file mode 100644 index 0000000000..7db782a862 --- /dev/null +++ b/invokeai/backend/model_manager/configs/controlnet.py @@ -0,0 +1,195 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.flux.controlnet.state_dict_utils import ( + is_state_dict_instantx_controlnet, + is_state_dict_xlabs_controlnet, +) +from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "ControlNetModel", + "FluxControlNetModel", + }, + ) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = 
cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config_dict = get_config_dict_or_raise(common_config_paths(mod.path)) + + if config_dict.get("_class_name") == "FluxControlNetModel": + return BaseModelType.Flux + + dimension = config_dict.get("cross_attention_dim") + + match dimension: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them + # anyway. + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross_attention_dim {dimension}") + + +class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + @classmethod + def 
from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_controlnet(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "controlnet", + "control_model", + "input_blocks", + # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." + # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors + # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with + # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so + # delicate. + "controlnet_blocks", + }, + ): + raise NotAMatchError("state dict does not look like a ControlNet checkpoint") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict): + # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing + # get_format()? 
+ return BaseModelType.Flux + + for key in ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "controlnet_mid_block.bias", + "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + ): + if key not in state_dict: + continue + width = state_dict[key].shape[-1] + match width: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case 1280: + return BaseModelType.StableDiffusionXL + case _: + pass + + raise NotAMatchError("unable to determine base type from state dict") + + +class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/invokeai/backend/model_manager/configs/external_api.py b/invokeai/backend/model_manager/configs/external_api.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py new file mode 100644 index 0000000000..27b6f252f1 --- /dev/null +++ b/invokeai/backend/model_manager/configs/factory.py @@ -0,0 +1,340 @@ +import logging +from pathlib import Path +from typing import ( + Union, +) + +from pydantic import Discriminator, TypeAdapter, ValidationError +from 
typing_extensions import Annotated, Any + +from invokeai.app.services.config.config_default import get_config +from invokeai.app.util.misc import uuid_string +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config +from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config +from invokeai.backend.model_manager.configs.controlnet import ( + ControlNet_Checkpoint_FLUX_Config, + ControlNet_Checkpoint_SD1_Config, + ControlNet_Checkpoint_SD2_Config, + ControlNet_Checkpoint_SDXL_Config, + ControlNet_Diffusers_FLUX_Config, + ControlNet_Diffusers_SD1_Config, + ControlNet_Diffusers_SD2_Config, + ControlNet_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config +from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError +from invokeai.backend.model_manager.configs.ip_adapter import ( + IPAdapter_Checkpoint_FLUX_Config, + IPAdapter_Checkpoint_SD1_Config, + IPAdapter_Checkpoint_SD2_Config, + IPAdapter_Checkpoint_SDXL_Config, + IPAdapter_InvokeAI_SD1_Config, + IPAdapter_InvokeAI_SD2_Config, + IPAdapter_InvokeAI_SDXL_Config, +) +from invokeai.backend.model_manager.configs.llava_onevision import LlavaOnevision_Diffusers_Config +from invokeai.backend.model_manager.configs.lora import ( + ControlLoRA_LyCORIS_FLUX_Config, + LoRA_Diffusers_FLUX_Config, + LoRA_Diffusers_SD1_Config, + LoRA_Diffusers_SD2_Config, + LoRA_Diffusers_SDXL_Config, + LoRA_LyCORIS_FLUX_Config, + LoRA_LyCORIS_SD1_Config, + LoRA_LyCORIS_SD2_Config, + LoRA_LyCORIS_SDXL_Config, + LoRA_OMI_FLUX_Config, + LoRA_OMI_SDXL_Config, +) +from invokeai.backend.model_manager.configs.main import ( + Main_BnBNF4_FLUX_Config, + Main_Checkpoint_FLUX_Config, + Main_Checkpoint_SD1_Config, + 
Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, + Main_Diffusers_CogView4_Config, + Main_Diffusers_SD1_Config, + Main_Diffusers_SD2_Config, + Main_Diffusers_SD3_Config, + Main_Diffusers_SDXL_Config, + Main_Diffusers_SDXLRefiner_Config, + Main_ExternalAPI_ChatGPT4o_Config, + Main_ExternalAPI_FluxKontext_Config, + Main_ExternalAPI_Gemini2_5_Config, + Main_ExternalAPI_Imagen3_Config, + Main_ExternalAPI_Imagen4_Config, + Main_GGUF_FLUX_Config, + Video_ExternalAPI_Runway_Config, + Video_ExternalAPI_Veo3_Config, +) +from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config +from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config +from invokeai.backend.model_manager.configs.t2i_adapter import ( + T2IAdapter_Diffusers_SD1_Config, + T2IAdapter_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config +from invokeai.backend.model_manager.configs.textual_inversion import ( + TI_File_SD1_Config, + TI_File_SD2_Config, + TI_File_SDXL_Config, + TI_Folder_SD1_Config, + TI_Folder_SD2_Config, + TI_Folder_SDXL_Config, +) +from invokeai.backend.model_manager.configs.unknown import Unknown_Config +from invokeai.backend.model_manager.configs.vae import ( + VAE_Checkpoint_FLUX_Config, + VAE_Checkpoint_SD1_Config, + VAE_Checkpoint_SD2_Config, + VAE_Checkpoint_SDXL_Config, + VAE_Diffusers_SD1_Config, + VAE_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelSourceType, + ModelType, + variant_type_adapter, +) + +logger = logging.getLogger(__name__) +app_config = get_config() + + +# The types are listed explicitly because IDEs/LSPs can't identify the correct types +# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes +AnyModelConfig 
= Annotated[ + Union[ + # Main (Pipeline) - diffusers format + Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()], + Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()], + Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()], + Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()], + Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()], + Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], + # Main (Pipeline) - checkpoint format + Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()], + Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()], + Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()], + Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()], + Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()], + # Main (Pipeline) - quantized formats + Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], + Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], + # VAE - checkpoint format + Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], + Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()], + Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()], + Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()], + # VAE - diffusers format + Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()], + Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()], + # ControlNet - checkpoint format + Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SDXL_Config, 
ControlNet_Checkpoint_SDXL_Config.get_tag()], + Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()], + # ControlNet - diffusers format + Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()], + Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()], + Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()], + Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()], + # LoRA - LyCORIS format + Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()], + Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()], + Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()], + Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], + # LoRA - OMI format + Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], + Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()], + # LoRA - diffusers format + Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()], + Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()], + Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()], + Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()], + # ControlLoRA - diffusers format + Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()], + # T5 Encoder - all formats + Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()], + Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()], + # TI - file format + Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()], + Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()], + Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()], + # TI - folder format + Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()], + Annotated[TI_Folder_SD2_Config, 
TI_Folder_SD2_Config.get_tag()], + Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()], + # IP Adapter - InvokeAI format + Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()], + Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()], + Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()], + # IP Adapter - checkpoint format + Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()], + # T2I Adapter - diffusers format + Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()], + Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()], + # Misc models + Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()], + Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()], + Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()], + Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()], + Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()], + Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()], + Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()], + # Main - external API + Annotated[Main_ExternalAPI_ChatGPT4o_Config, Main_ExternalAPI_ChatGPT4o_Config.get_tag()], + Annotated[Main_ExternalAPI_Gemini2_5_Config, Main_ExternalAPI_Gemini2_5_Config.get_tag()], + Annotated[Main_ExternalAPI_Imagen3_Config, Main_ExternalAPI_Imagen3_Config.get_tag()], + Annotated[Main_ExternalAPI_Imagen4_Config, Main_ExternalAPI_Imagen4_Config.get_tag()], + 
Annotated[Main_ExternalAPI_FluxKontext_Config, Main_ExternalAPI_FluxKontext_Config.get_tag()], + # Video - external API + Annotated[Video_ExternalAPI_Veo3_Config, Video_ExternalAPI_Veo3_Config.get_tag()], + Annotated[Video_ExternalAPI_Runway_Config, Video_ExternalAPI_Runway_Config.get_tag()], + # Unknown model (fallback) + Annotated[Unknown_Config, Unknown_Config.get_tag()], + ], + Discriminator(Config_Base.get_model_discriminator_value), +] + +AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) + + +class ModelConfigFactory: + @staticmethod + def from_dict(fields: dict[str, Any]) -> AnyModelConfig: + """Return the appropriate config object from raw dict values.""" + model = AnyModelConfigValidator.validate_python(fields) + return model + + @staticmethod + def build_common_fields( + mod: ModelOnDisk, + override_fields: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Builds the common fields for all model configs. + + Args: + mod: The model on disk to extract fields from. + overrides: A optional dictionary of fields to override. These fields will take precedence over the values + extracted from the model on disk. + + - Casts string fields to their Enum types. + - Does not validate the fields against the model config schema. 
+ """ + + _overrides: dict[str, Any] = override_fields or {} + fields: dict[str, Any] = {} + + if "type" in _overrides: + fields["type"] = ModelType(_overrides["type"]) + + if "format" in _overrides: + fields["format"] = ModelFormat(_overrides["format"]) + + if "base" in _overrides: + fields["base"] = BaseModelType(_overrides["base"]) + + if "source_type" in _overrides: + fields["source_type"] = ModelSourceType(_overrides["source_type"]) + + if "variant" in _overrides: + fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"]) + + fields["path"] = mod.path.as_posix() + fields["source"] = _overrides.get("source") or fields["path"] + fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path + fields["name"] = _overrides.get("name") or mod.name + fields["hash"] = _overrides.get("hash") or mod.hash() + fields["key"] = _overrides.get("key") or uuid_string() + fields["description"] = _overrides.get("description") + fields["file_size"] = _overrides.get("file_size") or mod.size() + + return fields + + @staticmethod + def from_model_on_disk( + mod: str | Path | ModelOnDisk, + override_fields: dict[str, Any] | None = None, + hash_algo: HASHING_ALGORITHMS = "blake3_single", + ) -> AnyModelConfig: + """ + Returns the best matching ModelConfig instance from a model's file/folder path. + Raises InvalidModelConfigException if no valid configuration is found. + Created to deprecate ModelProbe.probe + """ + if isinstance(mod, Path | str): + mod = ModelOnDisk(Path(mod), hash_algo) + + # We will always need these fields to build any model config. + fields = ModelConfigFactory.build_common_fields(mod, override_fields) + + # Store results as a mapping of config class to either an instance of that class or an exception + # that was raised when trying to build it. + results: dict[str, AnyModelConfig | Exception] = {} + + # Try to build an instance of each model config class that uses the classify API. 
+ # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. + # Other exceptions may be raised if something unexpected happens during matching or building. + for config_class in Config_Base.CONFIG_CLASSES: + class_name = config_class.__name__ + try: + instance = config_class.from_model_on_disk(mod, fields) + # Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of + # the AnyModelConfig union. + results[class_name] = instance # type: ignore + except NotAMatchError as e: + results[class_name] = e + logger.debug(f"No match for {config_class.__name__} on model {mod.name}") + except ValidationError as e: + # This means the model matched, but we couldn't create the pydantic model instance for the config. + # Maybe invalid overrides were provided? + results[class_name] = e + logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}") + except Exception as e: + results[class_name] = e + logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}") + + matches = [r for r in results.values() if isinstance(r, Config_Base)] + + if not matches and app_config.allow_unknown_models: + logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config") + return Unknown_Config(**fields) + + if len(matches) > 1: + # We have multiple matches, in which case at most 1 is correct. We need to pick one. + # + # Known cases: + # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model. + # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with + # a config.json file. Prefer the main model. + + # Sort the matching according to known special cases. 
+ def sort_key(m: AnyModelConfig) -> int: + match m.type: + case ModelType.Main: + return 0 + case ModelType.LoRA: + return 1 + case ModelType.CLIPEmbed: + return 2 + case _: + return 3 + + matches.sort(key=sort_key) + logger.warning( + f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}." + ) + + instance = matches[0] + logger.info(f"Model {mod.name} classified as {type(instance).__name__}") + return instance diff --git a/invokeai/backend/model_manager/configs/flux_redux.py b/invokeai/backend/model_manager/configs/flux_redux.py new file mode 100644 index 0000000000..6eb76116fb --- /dev/null +++ b/invokeai/backend/model_manager/configs/flux_redux.py @@ -0,0 +1,40 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_file, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class FLUXRedux_Checkpoint_Config(Config_Base): + """Model config for FLUX Tools Redux model.""" + + type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + if not is_state_dict_likely_flux_redux(mod.load_state_dict()): + raise NotAMatchError("model does not match FLUX Tools 
# --- file: invokeai/backend/model_manager/configs/identification_utils.py ---

import json
from functools import cache
from pathlib import Path

from pydantic import BaseModel, ValidationError
from pydantic_core import CoreSchema, SchemaValidator
from typing_extensions import Any

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk


class NotAMatchError(Exception):
    """Raised when a model on disk does not match a candidate config class.

    Args:
        reason: The reason why the model did not match.
    """

    def __init__(self, reason: str):
        super().__init__(reason)


def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
    """Load and return the first readable JSON config among the candidate path(s).

    Raises:
        NotAMatchError: If none of the candidate files can be loaded.
    """
    candidates = config_path if isinstance(config_path, set) else {config_path}

    problems: dict[Path, str] = {}

    for candidate in candidates:
        if not candidate.exists():
            problems[candidate] = "file does not exist"
            continue

        try:
            with open(candidate, "r") as file:
                return json.load(file)
        except Exception as e:
            problems[candidate] = str(e)

    raise NotAMatchError(f"unable to load config file(s): {problems}")


def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
    """Load the diffusers/transformers model config file and return the class name.

    Raises:
        NotAMatchError: If the config file is missing or does not contain a valid class name.
    """

    config = get_config_dict_or_raise(config_path)

    try:
        if "_class_name" in config:
            # diffusers-style config
            name = config["_class_name"]
        elif "architectures" in config:
            # transformers-style config
            name = config["architectures"][0]
        else:
            raise ValueError("missing _class_name or architectures field")
    except Exception as e:
        raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e

    if not isinstance(name, str):
        raise NotAMatchError(f"_class_name or architectures field is not a string: {name}")

    return name


def raise_for_class_name(config_path: Path | set[Path], expected: set[str]) -> None:
    """Get the class name from the config file and raise if it is not one of the expected names.

    Args:
        config_path: The path to the config file.
        expected: The expected class names.

    Raises:
        NotAMatchError: If the class name is not in the expected set.
    """

    class_name = get_class_name_from_config_dict_or_raise(config_path)
    if class_name not in expected:
        raise NotAMatchError(f"invalid class name from config: {class_name}")


def raise_for_override_fields(candidate_config_class: type[BaseModel], override_fields: dict[str, Any]) -> None:
    """Check the provided override fields against the config class's schema using pydantic.

    For example, if the candidate config class declares "base" as Literal[BaseModelType.StableDiffusion1], and the
    override fields contain "base": BaseModelType.Flux, this function raises.

    Args:
        candidate_config_class: The config class that is being tested.
        override_fields: The override fields provided by the user.

    Raises:
        NotAMatchError: If any override field is invalid for the config class.
    """
    for field_name, override_value in override_fields.items():
        if field_name not in candidate_config_class.model_fields:
            raise NotAMatchError(f"unknown override field: {field_name}")
        try:
            PydanticFieldValidator.validate_field(candidate_config_class, field_name, override_value)
        except ValidationError as e:
            raise NotAMatchError(f"invalid override for field '{field_name}': {e}") from e


def raise_if_not_file(mod: ModelOnDisk) -> None:
    """Raise NotAMatchError if the model path is not a file."""
    if mod.path.is_file():
        return
    raise NotAMatchError("model path is not a file")


def raise_if_not_dir(mod: ModelOnDisk) -> None:
    """Raise NotAMatchError if the model path is not a directory."""
    if mod.path.is_dir():
        return
    raise NotAMatchError("model path is not a directory")


def state_dict_has_any_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
    """Returns true if the state dict has any of the specified keys."""
    targets = {keys} if isinstance(keys, str) else keys
    return not targets.isdisjoint(state_dict.keys())


def state_dict_has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
    """Returns true if the state dict has any keys starting with any of the specified prefixes."""
    prefix_tuple = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes)
    return any(key.startswith(prefix_tuple) for key in state_dict if isinstance(key, str))


def state_dict_has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
    """Returns true if the state dict has any keys ending with any of the specified suffixes."""
    suffix_tuple = (suffixes,) if isinstance(suffixes, str) else tuple(suffixes)
    return any(key.endswith(suffix_tuple) for key in state_dict if isinstance(key, str))


def common_config_paths(path: Path) -> set[Path]:
    """Returns common config file paths for models stored in directories."""
    return {path / "config.json", path / "model_index.json"}
{path / "config.json", path / "model_index.json"} + + +class PydanticFieldValidator: + """Utility class for validating individual fields of a Pydantic model without instantiating the whole model. + + See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 + """ + + @staticmethod + def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: + """Find the Pydantic core schema for a specific field in a model.""" + schema: CoreSchema = model.__pydantic_core_schema__.copy() + # we shallow copied, be careful not to mutate the original schema! + + assert schema["type"] in ["definitions", "model"] + + # find the field schema + field_schema = schema["schema"] # type: ignore + while "fields" not in field_schema: + field_schema = field_schema["schema"] # type: ignore + + field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + + # if the original schema is a definition schema, replace the model schema with the field schema + if schema["type"] == "definitions": + schema["schema"] = field_schema + return schema + else: + return field_schema + + @cache + @staticmethod + def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator: + """Get a SchemaValidator for a specific field in a model.""" + return SchemaValidator(PydanticFieldValidator.find_field_schema(model, field_name)) + + @staticmethod + def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any: + """Validate a value for a specific field in a model.""" + return PydanticFieldValidator.get_validator(model, field_name).validate_python(value) diff --git a/invokeai/backend/model_manager/configs/ip_adapter.py b/invokeai/backend/model_manager/configs/ip_adapter.py new file mode 100644 index 0000000000..dc6f80d922 --- /dev/null +++ b/invokeai/backend/model_manager/configs/ip_adapter.py @@ -0,0 +1,180 @@ +from abc import ABC +from typing import ( + Literal, + Self, +) + +from pydantic import BaseModel, Field +from 
typing_extensions import Any + +from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class IPAdapter_Config_Base(ABC, BaseModel): + type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) + + +class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base): + """Model config for IP Adapter diffusers format models.""" + + format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) + + # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long + # time. Need to go through the history to make sure I'm understanding this fully. 
+ image_encoder_model_id: str = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_has_weights_file(mod) + + cls._validate_has_image_encoder_metadata_file(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: + weights_file = mod.path / "ip_adapter.bin" + if not weights_file.exists(): + raise NotAMatchError("missing ip_adapter.bin weights file") + + @classmethod + def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None: + image_encoder_metadata_file = mod.path / "image_encoder.txt" + if not image_encoder_metadata_file.exists(): + raise NotAMatchError("missing image_encoder.txt metadata file") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + try: + cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = 
Field(default=BaseModelType.StableDiffusion1) + + +class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base): + """Model config for IP Adapter checkpoint format models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_ip_adapter(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "image_proj.", + "ip_adapter.", + # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". 
+ "ip_adapter_proj_model.", + }, + ): + raise NotAMatchError("model does not match Checkpoint IP Adapter heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_ip_adapter(state_dict): + return BaseModelType.Flux + + try: + cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/invokeai/backend/model_manager/configs/llava_onevision.py b/invokeai/backend/model_manager/configs/llava_onevision.py new file mode 100644 index 0000000000..c6ceb43ca9 --- /dev/null +++ b/invokeai/backend/model_manager/configs/llava_onevision.py @@ -0,0 +1,42 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, 
Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + common_config_paths, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelType, +) + + +class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for Llava Onevision models.""" + + type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "LlavaOnevisionForConditionalGeneration", + }, + ) + + return cls(**override_fields) diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py new file mode 100644 index 0000000000..512137b4c3 --- /dev/null +++ b/invokeai/backend/model_manager/configs/lora.py @@ -0,0 +1,323 @@ +from abc import ABC +from pathlib import Path +from typing import ( + Any, + Literal, + Self, +) + +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Any + +from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.base import ( + Config_Base, +) +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_ending_with, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora +from 
invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + FluxLoRAFormat, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control + + +class LoraModelDefaultSettings(BaseModel): + weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model") + model_config = ConfigDict(extra="forbid") + + +class LoRA_Config_Base(ABC, BaseModel): + """Base class for LoRA models.""" + + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) + trigger_phrases: set[str] | None = Field( + default=None, + description="Set of trigger phrases for this model", + ) + default_settings: LoraModelDefaultSettings | None = Field( + default=None, + description="Default settings for this model", + ) + + +def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: + # TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later. 
+ from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict + + state_dict = mod.load_state_dict(mod.path) + value = flux_format_from_state_dict(state_dict, mod.metadata()) + return value + + +class LoRA_OMI_Config_Base(LoRA_Config_Base): + format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_omi_lora(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA") + + metadata = mod.metadata() + + metadata_looks_like_omi_lora = ( + bool(metadata.get("modelspec.sai_model_spec")) + and metadata.get("ot_branch") == "omi_format" + and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" + ) + + if not metadata_looks_like_omi_lora: + raise NotAMatchError("metadata does not look like OMI LoRA") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]: + metadata = mod.metadata() + architecture = metadata["modelspec.architecture"] + + if architecture == stable_diffusion_xl_1_lora: + return BaseModelType.StableDiffusionXL 
+ elif architecture == flux_dev_1_lora: + return BaseModelType.Flux + else: + raise NotAMatchError(f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") + + +class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Lycoris models.""" + + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_lora(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + # First rule out ControlLoRA and Diffusers LoRA + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA") + + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. 
+ has_key_with_lora_prefix = state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + }, + ) + + has_key_with_lora_suffix = state_dict_has_any_keys_ending_with( + mod.load_state_dict(), + { + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + }, + ) + + if not has_key_with_lora_prefix and not has_key_with_lora_suffix: + raise NotAMatchError("model does not match LyCORIS LoRA heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux + + state_dict = mod.load_state_dict() + # If we've gotten here, we assume that the model is a Stable Diffusion model + token_vector_length = lora_token_vector_length(state_dict) + if token_vector_length == 768: + return BaseModelType.StableDiffusion1 + elif token_vector_length == 1024: + return BaseModelType.StableDiffusion2 + elif token_vector_length == 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + elif token_vector_length == 2048: + return BaseModelType.StableDiffusionXL + else: + raise NotAMatchError(f"unrecognized token vector length {token_vector_length}") + + +class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class 
ControlAdapter_Config_Base(ABC, BaseModel): + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + +class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base): + """Model config for Control LoRA models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) + + trigger_phrases: set[str] | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_control_lora(mod) + + return cls(**override_fields) + + @classmethod + def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None: + state_dict = mod.load_state_dict() + + if not is_state_dict_likely_flux_control(state_dict): + raise NotAMatchError("model state dict does not look like a Flux Control LoRA") + + +class LoRA_Diffusers_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Diffusers models.""" + + # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates + # the weights format. FLUX Diffusers LoRAs are single files. 
+ + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux + + # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA + path_to_weight_file = cls._get_weight_file_or_raise(mod) + state_dict = mod.load_state_dict(path_to_weight_file) + token_vector_length = lora_token_vector_length(state_dict) + + match token_vector_length: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized token vector length {token_vector_length}") + + @classmethod + def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path: + suffixes = ["bin", "safetensors"] + weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] + for wf in weight_files: + if wf.exists(): + return wf + raise NotAMatchError("missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") + + +class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = 
Field(default=BaseModelType.StableDiffusion1) + + +class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py new file mode 100644 index 0000000000..ef1ab1fe77 --- /dev/null +++ b/invokeai/backend/model_manager/configs/main.py @@ -0,0 +1,692 @@ +from abc import ABC +from typing import Any, Literal, Self + +from pydantic import BaseModel, ConfigDict, Field + +from invokeai.backend.model_manager.configs.base import ( + Checkpoint_Config_Base, + Config_Base, + Diffusers_Config_Base, + SubmodelDefinition, +) +from invokeai.backend.model_manager.configs.clip_embed import get_clip_variant_type_from_config +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_exact, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + FluxVariantType, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SubModelType, +) +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES + +DEFAULTS_PRECISION = Literal["fp16", "fp32"] + + +class MainModelDefaultSettings(BaseModel): + vae: str | None = Field(default=None, 
description="Default VAE for this model (model key)") + vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model") + scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model") + steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model") + cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model") + cfg_rescale_multiplier: float | None = Field( + default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model" + ) + width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model") + height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model") + guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model") + + model_config = ConfigDict(extra="forbid") + + +class Main_Config_Base(ABC, BaseModel): + type: Literal[ModelType.Main] = Field(default=ModelType.Main) + trigger_phrases: set[str] | None = Field( + default=None, + description="Set of trigger phrases for this model", + ) + default_settings: MainModelDefaultSettings | None = Field( + default=None, + description="Default settings for this model", + ) + + +def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool: + bnb_nf4_keys = { + "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + } + return any(key in state_dict for key in bnb_nf4_keys) + + +def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: + return any(isinstance(v, GGMLTensor) for v in state_dict.values()) + + +def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: + for key in state_dict.keys(): + if isinstance(key, int): + continue + elif key.startswith( + ( + "cond_stage_model.", + 
"first_stage_model.", + "model.diffusion_model.", + # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". + # This prefix is typically used to distinguish between multiple models bundled in a single file. + "model.diffusion_model.double_blocks.", + ) + ): + return True + elif key.startswith("double_blocks.") and "ip_adapter" not in key: + # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be + # careful to avoid false positives on XLabs FLUX IP-Adapter models. + return True + return False + + +class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_base(mod) + + prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, prediction_type=prediction_type, variant=variant) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + key_name = 
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + return BaseModelType.StableDiffusion2 + + key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: + return BaseModelType.StableDiffusionXL + elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: + return BaseModelType.StableDiffusionXLRefiner + + raise NotAMatchError("unable to determine base type from state dict") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + base = cls.model_fields["base"].default + + if base is BaseModelType.StableDiffusion2: + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if "global_step" in state_dict: + if state_dict["global_step"] == 220000: + return SchedulerPredictionType.Epsilon + elif state_dict["global_step"] == 110000: + return SchedulerPredictionType.VPrediction + return SchedulerPredictionType.VPrediction + else: + return SchedulerPredictionType.Epsilon + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.0.0.weight" + + if key_name not in state_dict: + raise NotAMatchError("unable to determine model variant from state dict") + + in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, 
f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'") + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + +class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) + + +def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + + # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight". 
+ # + # Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + # + # Input channels for known FLUX models: + # - Unquantized Dev and Schnell have in_channels=64 + # - BNB-NF4 Dev and Schnell have in_channels=1 + # - FLUX Fill has in_channels=384 + # - Unsure of quantized FLUX Fill models + # - Unsure of GGUF-quantized models + + in_channels = None + for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}: + if key in state_dict: + in_channels = state_dict[key].shape[1] + break + + if in_channels is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + return None + + # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of + # certain keys to distinguish between them. + is_flux_dev = ( + "guidance_in.out_layer.weight" in state_dict + or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict + ) + + if is_flux_dev and in_channels == 384: + return FluxVariantType.DevFill + elif is_flux_dev: + return FluxVariantType.Dev + else: + # Must be a Schnell model...? 
+ return FluxVariantType.Schnell + + +class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_is_flux(mod) + + cls._validate_does_not_look_like_bnb_quantized(mod) + + cls._validate_does_not_look_like_gguf_quantized(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, variant=variant) + + @classmethod + def _validate_is_flux(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_exact( + mod.load_state_dict(), + { + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + }, + ): + raise NotAMatchError("state dict does not look like a FLUX checkpoint") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+ raise NotAMatchError("unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + @classmethod + def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if has_bnb_nf4_keys: + raise NotAMatchError("state dict looks like bnb quantized nf4") + + @classmethod + def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk): + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if has_ggml_tensors: + raise NotAMatchError("state dict looks like GGUF quantized") + + +class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_model_looks_like_bnb_quantized(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? 
Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + raise NotAMatchError("unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + @classmethod + def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if not has_bnb_nf4_keys: + raise NotAMatchError("state dict does not look like bnb quantized nf4") + + +class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_looks_like_gguf_quantized(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? 
Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + raise NotAMatchError("unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + @classmethod + def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None: + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if not has_ggml_tensors: + raise NotAMatchError("state dict does not look like GGUF quantized") + + +class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + # SD 1.x and 2.x + "StableDiffusionPipeline", + "StableDiffusionInpaintPipeline", + # SDXL + "StableDiffusionXLPipeline", + "StableDiffusionXLInpaintPipeline", + # SDXL Refiner + "StableDiffusionXLImg2ImgPipeline", + # TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase. 
+ "LatentConsistencyModelPipeline", + }, + ) + + cls._validate_base(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **override_fields, + variant=variant, + prediction_type=prediction_type, + repo_variant=repo_variant, + ) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). + unet_conf = get_config_dict_or_raise(mod.path / "unet" / "config.json") + cross_attention_dim = unet_conf.get("cross_attention_dim") + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXLRefiner + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross_attention_dim {cross_attention_dim}") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + scheduler_conf = get_config_dict_or_raise(mod.path / "scheduler" / "scheduler_config.json") + + # TODO(psyche): Is epsilon the right default or should we raise if it's not present? 
+ prediction_type = scheduler_conf.get("prediction_type", "epsilon") + + match prediction_type: + case "v_prediction": + return SchedulerPredictionType.VPrediction + case "epsilon": + return SchedulerPredictionType.Epsilon + case _: + raise NotAMatchError(f"unrecognized scheduler prediction_type {prediction_type}") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + unet_config = get_config_dict_or_raise(mod.path / "unet" / "config.json") + in_channels = unet_config.get("in_channels") + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'") + + +class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) + + +class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) + + +class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) + + +class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) + + +class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) + submodels: dict[SubModelType, SubmodelDefinition] | None = Field( + description="Loadable submodels in this model", + default=None, + ) + + @classmethod + def 
from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # This check implies the base type - no further validation needed.
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "StableDiffusion3Pipeline",
+                "SD3Transformer2DModel",
+            },
+        )
+
+        # Merge computed values into a copy of the overrides. Passing them as explicit keywords alongside
+        # **override_fields would raise "TypeError: got multiple values for keyword argument" whenever the
+        # caller supplies a "submodels" or "repo_variant" override - which the .get() calls anticipate.
+        fields = dict(override_fields)
+        fields["submodels"] = fields.get("submodels") or cls._get_submodels_or_raise(mod)
+        fields["repo_variant"] = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+        return cls(**fields)
+
+    @classmethod
+    def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
+        # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
+        config = get_config_dict_or_raise(common_config_paths(mod.path))
+
+        submodels: dict[SubModelType, SubmodelDefinition] = {}
+
+        for key, value in config.items():
+            # Anything that starts with an underscore is top-level metadata, not a submodel
+            if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
+                continue
+            # The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
+            # The value value is something like ["diffusers", "SD3Transformer2DModel"]
+            _library_name, class_name = value
+
+            match class_name:
+                case "CLIPTextModelWithProjection":
+                    model_type = ModelType.CLIPEmbed
+                    path_or_prefix = (mod.path / key).resolve().as_posix()
+
+                    # We need to read the config to determine the variant of the CLIP model.
+ clip_embed_config = get_config_dict_or_raise( + { + mod.path / key / "config.json", + mod.path / key / "model_index.json", + } + ) + variant = get_clip_variant_type_from_config(clip_embed_config) + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case "SD3Transformer2DModel": + model_type = ModelType.Main + path_or_prefix = (mod.path / key).resolve().as_posix() + variant = None + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case _: + pass + + return submodels + + +class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + # This check implies the base type - no further validation needed. 
+ raise_for_class_name( + common_config_paths(mod.path), + { + "CogView4Pipeline", + }, + ) + + repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **override_fields, + repo_variant=repo_variant, + ) + + +class ExternalAPI_Config_Base(ABC, BaseModel): + """Model config for API-based models.""" + + format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise NotAMatchError("External API models cannot be built from disk") + + +class Main_ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) + + +class Main_ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) + + +class Main_ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) + + +class Main_ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4) + + +class Main_ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + + +class Video_Config_Base(ABC, BaseModel): + type: Literal[ModelType.Video] = Field(default=ModelType.Video) + trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None) + default_settings: MainModelDefaultSettings | None = Field( + description="Default settings for this model", default=None + ) + + +class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base): + base: 
Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3)
+
+
+class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
+    # NOTE: these two video configs previously declared base=FluxKontext, copy-pasted from
+    # Main_ExternalAPI_FluxKontext_Config above; each must use its own video base type.
+    base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway)
diff --git a/invokeai/backend/model_manager/configs/siglip.py b/invokeai/backend/model_manager/configs/siglip.py
new file mode 100644
index 0000000000..62ca9494e2
--- /dev/null
+++ b/invokeai/backend/model_manager/configs/siglip.py
@@ -0,0 +1,44 @@
+from typing import (
+    Literal,
+    Self,
+)
+
+from pydantic import Field
+from typing_extensions import Any
+
+from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.identification_utils import (
+    common_config_paths,
+    raise_for_class_name,
+    raise_for_override_fields,
+    raise_if_not_dir,
+)
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+)
+
+
+class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
+    """Model config for SigLIP."""
+
+    type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
+    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "SiglipModel",
+            },
+        )
+
+        return cls(**override_fields)
diff --git a/invokeai/backend/model_manager/configs/spandrel.py b/invokeai/backend/model_manager/configs/spandrel.py
new file mode 100644
index 0000000000..8ca8ad5f60
--- /dev/null
+++ b/invokeai/backend/model_manager/configs/spandrel.py
@@ -0,0 +1,54 @@
+from typing import (
+    Literal,
+    Self,
+)
+
class Spandrel_Checkpoint_Config(Config_Base):
    """Config for single-file image-to-image models that are loadable by the spandrel library."""

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` if spandrel cannot load it."""
        # Checkpoint-format models are single files, never directories.
        raise_if_not_file(mod)

        # Validate caller-supplied override fields against this config class.
        raise_for_override_fields(cls, override_fields)

        cls._validate_spandrel_loads_model(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
        """Raise `NotAMatchError` unless spandrel can actually load the file.

        Loading the whole model from disk is unfortunately unavoidable here. Two cheaper
        alternatives were explored and rejected:
        - Loading via `SpandrelImageToImageModel.load_from_state_dict` with a meta-device
          state dict fails, because some spandrel models run tensor operations during
          initialization that meta tensors do not support.
        - Spandrel's own architecture-detection logic is not part of its public API;
          copying it here would be a maintenance burden with a higher risk of false
          positive detections.
        """
        try:
            SpandrelImageToImageModel.load_from_file(mod.path)
        except Exception as e:
            raise NotAMatchError("model does not match SpandrelImageToImage heuristics") from e
class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base):
    """Base config for T2I-Adapter models in diffusers format."""

    type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
    default_settings: ControlAdapterDefaultSettings | None = Field(None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` on mismatch."""
        # Diffusers-format models live in a directory, never a single file.
        raise_if_not_dir(mod)

        # Validate caller-supplied override fields against this config class.
        raise_for_override_fields(cls, override_fields)

        # The model's config JSON must declare the `T2IAdapter` architecture.
        expected_class_names = {"T2IAdapter"}
        raise_for_class_name(common_config_paths(mod.path), expected_class_names)

        cls._validate_base(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_base(cls, mod: ModelOnDisk) -> None:
        """Raise `NotAMatch` if the model base does not match this config class."""
        expected_base = cls.model_fields["base"].default
        recognized_base = cls._get_base_or_raise(mod)
        if expected_base is not recognized_base:
            raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        """Infer the base model from the `adapter_type` entry of the model's config file."""
        config_dict = get_config_dict_or_raise(common_config_paths(mod.path))

        adapter_type = config_dict.get("adapter_type")

        # "full_adapter_xl" is the SDXL variant; both SD 1.x variants share the same bases.
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        if adapter_type in ("full_adapter", "light_adapter"):
            return BaseModelType.StableDiffusion1
        raise NotAMatchError(f"unrecognized adapter_type '{adapter_type}'")


class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)


class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class T5Encoder_T5Encoder_Config(Config_Base):
    """Configuration for T5 Encoder models in a bespoke, diffusers-like format.

    The model weights are expected to be in a folder called text_encoder_2 inside the model
    directory, with a config file named model.safetensors.index.json.
    """

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
    format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` on mismatch."""
        # This format is a directory, never a single file.
        raise_if_not_dir(mod)

        # Validate caller-supplied override fields against this config class.
        raise_for_override_fields(cls, override_fields)

        # The model's config JSON must declare the `T5EncoderModel` architecture.
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "T5EncoderModel",
            },
        )

        cls.raise_if_doesnt_have_unquantized_config_file(mod)

        return cls(**override_fields)

    @classmethod
    def raise_if_doesnt_have_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
        """Require the sharded-weights index file; its presence distinguishes this format
        from the quantized variants below."""
        has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()

        if not has_unquantized_config:
            raise NotAMatchError("missing text_encoder_2/model.safetensors.index.json")


class T5Encoder_BnBLLMint8_Config(Config_Base):
    """Configuration for T5 Encoder models quantized by bitsandbytes' LLM.int8."""

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
    format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` on mismatch."""
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "T5EncoderModel",
            },
        )

        cls.raise_if_filename_doesnt_look_like_bnb_quantized(mod)

        cls.raise_if_state_dict_doesnt_look_like_bnb_quantized(mod)

        return cls(**override_fields)

    @classmethod
    def raise_if_filename_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        """Require at least one weight file whose path mentions llm_int8."""
        # Idiom fix: test the predicate inside any() rather than any() over a filtered
        # generator of (always-truthy) Path objects. Behavior is unchanged.
        filename_looks_like_bnb = any("llm_int8" in f.as_posix() for f in mod.weight_files())
        if not filename_looks_like_bnb:
            raise NotAMatchError("filename does not look like bnb quantized llm_int8")

    @classmethod
    def raise_if_state_dict_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        """Require the "SCB" key suffix characteristic of bitsandbytes LLM.int8 checkpoints."""
        has_scb_key_suffix = state_dict_has_any_keys_ending_with(mod.load_state_dict(), "SCB")
        if not has_scb_key_suffix:
            raise NotAMatchError("state dict does not look like bnb quantized llm_int8")
False + + if p.is_dir(): + return False + + if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]: + return True + + state_dict = mod.load_state_dict(p) + + # Heuristic: textual inversion embeddings have these keys + if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()): + return True + + # Heuristic: small state dict with all tensor values + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): + return True + + return False + except Exception: + return False + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: + p = path or mod.path + + try: + state_dict = mod.load_state_dict(p) + except Exception as e: + raise NotAMatchError(f"unable to load state dict from {p}: {e}") from e + + try: + if "string_to_token" in state_dict: + token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] + elif "emb_params" in state_dict: + token_dim = state_dict["emb_params"].shape[-1] + elif "clip_g" in state_dict: + token_dim = state_dict["clip_g"].shape[-1] + else: + token_dim = list(state_dict.values())[0].shape[0] + except Exception as e: + raise NotAMatchError(f"unable to determine token dimension from state dict in {p}: {e}") from e + + match token_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized token dimension {token_dim}") + + +class TI_File_Config_Base(TI_Config_Base): + """Model config for textual inversion embeddings.""" + + format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + if not cls._file_looks_like_embedding(mod): + raise 
NotAMatchError("model does not look like a textual inversion embedding file") + + cls._validate_base(mod) + + return cls(**override_fields) + + +class TI_File_SD1_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class TI_File_SD2_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class TI_Folder_Config_Base(TI_Config_Base): + """Model config for textual inversion embeddings.""" + + format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + for p in mod.weight_files(): + if cls._file_looks_like_embedding(mod, p): + cls._validate_base(mod, p) + return cls(**override_fields) + + raise NotAMatchError("model does not look like a textual inversion embedding folder") + + +class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) diff --git a/invokeai/backend/model_manager/configs/unknown.py b/invokeai/backend/model_manager/configs/unknown.py new file mode 100644 index 0000000000..10aad75566 --- /dev/null +++ b/invokeai/backend/model_manager/configs/unknown.py @@ -0,0 +1,24 @@ +from typing 
class Unknown_Config(Config_Base):
    """Fallback config used when we cannot identify a model.

    Identification never matches this class — presumably the factory constructs it
    directly as a last resort after every other config class has been rejected.
    """

    base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
    type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
    format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        # Unconditionally refuse to match during identification.
        raise NotAMatchError("unknown model config cannot match any model")
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
    """Model config for standalone VAE models in single-file (checkpoint) format."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` on mismatch."""
        # Local import: the module-level import block does not currently expose this helper.
        from invokeai.backend.model_manager.configs.identification_utils import raise_if_not_file

        # BUGFIX: checkpoint-format models are single files, not directories (compare
        # Spandrel_Checkpoint_Config, and the filename-based base heuristic below). The
        # previous raise_if_not_dir check rejected every checkpoint VAE file, so this
        # config could never match.
        raise_if_not_file(mod)

        # Validate caller-supplied override fields against this config class.
        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_vae(mod)

        cls._validate_base(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_base(cls, mod: ModelOnDisk) -> None:
        """Raise `NotAMatch` if the model base does not match this config class."""
        expected_base = cls.model_fields["base"].default
        recognized_base = cls._get_base_or_raise(mod)
        if expected_base is not recognized_base:
            raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")

    @classmethod
    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
        """Raise unless the state dict has the encoder/decoder input-conv keys typical of a VAE."""
        if not state_dict_has_any_keys_starting_with(
            mod.load_state_dict(),
            {
                "encoder.conv_in",
                "decoder.conv_in",
            },
        ):
            raise NotAMatchError("model does not match Checkpoint VAE heuristics")

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        # Heuristic: VAEs of all architectures have a similar structure; the best we can do
        # is guess based on name. REGEX_TO_BASE is ordered, so more specific patterns are
        # tried before the generic "vae" fallback.
        for regexp, base in REGEX_TO_BASE.items():
            if re.search(regexp, mod.path.name, re.IGNORECASE):
                return base

        raise NotAMatchError("cannot determine base type")


class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)


class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)


class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)


class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from the model on disk, raising `NotAMatchError` on mismatch."""
        # Diffusers-format models live in a directory, never a single file.
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # The model's config JSON must declare an AutoencoderKL or AutoencoderTiny architecture.
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "AutoencoderKL",
                "AutoencoderTiny",
            },
        )

        cls._validate_base(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_base(cls, mod: ModelOnDisk) -> None:
        """Raise `NotAMatch` if the model base does not match this config class."""
        expected_base = cls.model_fields["base"].default
        recognized_base = cls._get_base_or_raise(mod)
        if expected_base is not recognized_base:
            raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")

    @classmethod
    def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
        # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    @classmethod
    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
        # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
        # we can do is guess based on name.
        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))

    @classmethod
    def _guess_name(cls, mod: ModelOnDisk) -> str:
        """Return the model's folder name, falling back to the parent for the generic name "vae"."""
        name = mod.path.name
        if name == "vae":
            name = mod.path.parent.name
        return name

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base from config hyperparameters first, then the model name."""
        config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
        if cls._config_looks_like_sdxl(config_dict):
            return BaseModelType.StableDiffusionXL
        elif cls._name_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        else:
            # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
            return BaseModelType.StableDiffusion1


class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)


class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)