refactor(mm): split configs into separate files

This commit is contained in:
psychedelicious
2025-10-03 19:24:29 +10:00
parent 6cc67e53a8
commit d93d4242f9
20 changed files with 2982 additions and 0 deletions

View File

@@ -0,0 +1,243 @@
from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Self,
Type,
)
from pydantic import BaseModel, ConfigDict, Field, Tag
from pydantic_core import PydanticUndefined
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
if TYPE_CHECKING:
pass
class Config_Base(ABC, BaseModel):
"""
Abstract base class for model configurations. A model config describes a specific combination of model base, type and
format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format
would have base=sd-1, type=main, format=checkpoint.
To create a new config type, inherit from this class and implement its interface:
- Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be
called during model installation to determine the correct config class for a model.
- Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A
default must be provided for each of these fields.
If multiple combinations of base, type and format need to be supported, create a separate subclass for each.
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
# These fields are common to all model configs.
key: str = Field(
default_factory=uuid_string,
description="A unique key for this model.",
)
hash: str = Field(
description="The hash of the model file(s).",
)
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
)
file_size: int = Field(
description="The size of the model in bytes.",
)
name: str = Field(
description="Name of the model.",
)
description: str | None = Field(
default=None,
description="Model description",
)
source: str = Field(
description="The original source of the model (path, URL or repo_id).",
)
source_type: ModelSourceType = Field(
description="The type of source",
)
source_api_response: str | None = Field(
default=None,
description="The original API response from the source, as stringified JSON.",
)
cover_image: str | None = Field(
default=None,
description="Url for image to preview model",
)
usage_info: str | None = Field(
default=None,
description="Usage information for this model",
)
CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set()
"""Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set
of all known model config types."""
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
json_schema_mode_override="serialization",
)
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Register non-abstract subclasses so we can iterate over them later during model probing. Note that
# isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC.
# We must check for ABC lest we unintentionally register some abstract model config classes.
if not isabstract(cls) and ABC not in cls.__bases__:
cls.CONFIG_CLASSES.add(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
# Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each
# subclass is expected to represent a single combination of base, type and format.
#
# This pydantic dunder method is called after the pydantic model for a class is created. The normal
# __init_subclass__ is too early to do this check.
for name in ("type", "base", "format"):
if name not in cls.model_fields:
raise NotImplementedError(f"{cls.__name__} must define a '{name}' field")
if cls.model_fields[name].default is PydanticUndefined:
raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field")
@classmethod
def get_tag(cls) -> Tag:
"""Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
pydantic uses the tag to determine which subclass to instantiate.
The tag is a dot-separated string of the type, format, base and variant (if applicable).
"""
tag_strings: list[str] = []
for name in ("type", "format", "base", "variant"):
if field := cls.model_fields.get(name):
# The check in __pydantic_init_subclass__ ensures that type, format and base are always present with
# defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can
# check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by
# pydantic to indicate that no default was provided.
if field.default is not PydanticUndefined:
# We expect each of these fields has an Enum for its default; we want the value of the enum.
tag_strings.append(field.default.value)
return Tag(".".join(tag_strings))
@staticmethod
def get_model_discriminator_value(v: Any) -> str:
"""Computes the discriminator value for a model config discriminated union."""
# This is called by pydantic during deserialization and serialization to determine which model the data
# represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass
# (during serialization).
#
# See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
if isinstance(v, Config_Base):
# We have an instance of a ModelConfigBase subclass - use its tag directly.
return v.get_tag().tag
if isinstance(v, dict):
# We have a dict - attempt to compute a tag from its fields.
tag_strings: list[str] = []
if type_ := v.get("type"):
if isinstance(type_, Enum):
type_ = str(type_.value)
elif not isinstance(type_, str):
raise TypeError("Model config dict 'type' field must be a string or Enum")
tag_strings.append(type_)
if format_ := v.get("format"):
if isinstance(format_, Enum):
format_ = str(format_.value)
elif not isinstance(format_, str):
raise TypeError("Model config dict 'format' field must be a string or Enum")
tag_strings.append(format_)
if base_ := v.get("base"):
if isinstance(base_, Enum):
base_ = str(base_.value)
elif not isinstance(base_, str):
raise TypeError("Model config dict 'base' field must be a string or Enum")
tag_strings.append(base_)
# Special case: CLIP Embed models also need the variant to distinguish them.
if (
type_ == ModelType.CLIPEmbed.value
and format_ == ModelFormat.Diffusers.value
and base_ == BaseModelType.Any.value
):
if variant_ := v.get("variant"):
if isinstance(variant_, Enum):
variant_ = variant_.value
elif not isinstance(variant_, str):
raise TypeError("Model config dict 'variant' field must be a string or Enum")
tag_strings.append(variant_)
else:
raise ValueError("CLIP Embed model config dict must include a 'variant' field")
return ".".join(tag_strings)
else:
raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")
@abstractmethod
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
"""Given the model on disk and any override fields, attempt to construct an instance of this config class.
This method serves to identify whether the model on disk matches this config class, and if so, to extract any
additional metadata needed to instantiate the config.
Implementations should raise a NotAMatchError if the model does not match this config class."""
raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
class Checkpoint_Config_Base(ABC, BaseModel):
"""Base class for checkpoint-style models."""
config_path: str | None = Field(
description="Path to the config for this model, if any.",
default=None,
)
class Diffusers_Config_Base(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
repo_variant: ModelRepoVariant = Field(ModelRepoVariant.Default)
@classmethod
def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
# get all files ending in .bin or .safetensors
weight_files = list(mod.path.glob("**/*.safetensors"))
weight_files.extend(list(mod.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant | None = None
model_config = ConfigDict(protected_namespaces=())

View File

@@ -0,0 +1,91 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
ModelFormat,
ModelType,
)
def get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
try:
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
{
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
},
)
cls._validate_variant(mod)
return cls(**override_fields)
@classmethod
def _validate_variant(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model variant does not match this config class."""
expected_variant = cls.model_fields["variant"].default
config = get_config_dict_or_raise(
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
)
recognized_variant = get_clip_variant_type_from_config(config)
if recognized_variant is None:
raise NotAMatchError("unable to determine CLIP variant from config")
if expected_variant is not recognized_variant:
raise NotAMatchError(f"variant is {recognized_variant}, not {expected_variant}")
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)

View File

@@ -0,0 +1,44 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for CLIPVision."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"CLIPVisionModelWithProjection",
},
)
return cls(**override_fields)

View File

@@ -0,0 +1,195 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"ControlNetModel",
"FluxControlNetModel",
},
)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
if config_dict.get("_class_name") == "FluxControlNetModel":
return BaseModelType.Flux
dimension = config_dict.get("cross_attention_dim")
match dimension:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
# No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them
# anyway.
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross_attention_dim {dimension}")
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_controlnet(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
},
):
raise NotAMatchError("state dict does not look like a ControlNet checkpoint")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict):
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
):
if key not in state_dict:
continue
width = state_dict[key].shape[-1]
match width:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
raise NotAMatchError("unable to determine base type from state dict")
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)

View File

@@ -0,0 +1,340 @@
import logging
from pathlib import Path
from typing import (
Union,
)
from pydantic import Discriminator, TypeAdapter, ValidationError
from typing_extensions import Annotated, Any
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config
from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config
from invokeai.backend.model_manager.configs.controlnet import (
ControlNet_Checkpoint_FLUX_Config,
ControlNet_Checkpoint_SD1_Config,
ControlNet_Checkpoint_SD2_Config,
ControlNet_Checkpoint_SDXL_Config,
ControlNet_Diffusers_FLUX_Config,
ControlNet_Diffusers_SD1_Config,
ControlNet_Diffusers_SD2_Config,
ControlNet_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.configs.ip_adapter import (
IPAdapter_Checkpoint_FLUX_Config,
IPAdapter_Checkpoint_SD1_Config,
IPAdapter_Checkpoint_SD2_Config,
IPAdapter_Checkpoint_SDXL_Config,
IPAdapter_InvokeAI_SD1_Config,
IPAdapter_InvokeAI_SD2_Config,
IPAdapter_InvokeAI_SDXL_Config,
)
from invokeai.backend.model_manager.configs.llava_onevision import LlavaOnevision_Diffusers_Config
from invokeai.backend.model_manager.configs.lora import (
ControlLoRA_LyCORIS_FLUX_Config,
LoRA_Diffusers_FLUX_Config,
LoRA_Diffusers_SD1_Config,
LoRA_Diffusers_SD2_Config,
LoRA_Diffusers_SDXL_Config,
LoRA_LyCORIS_FLUX_Config,
LoRA_LyCORIS_SD1_Config,
LoRA_LyCORIS_SD2_Config,
LoRA_LyCORIS_SDXL_Config,
LoRA_OMI_FLUX_Config,
LoRA_OMI_SDXL_Config,
)
from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_FLUX_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Diffusers_CogView4_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SD3_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
Main_ExternalAPI_ChatGPT4o_Config,
Main_ExternalAPI_FluxKontext_Config,
Main_ExternalAPI_Gemini2_5_Config,
Main_ExternalAPI_Imagen3_Config,
Main_ExternalAPI_Imagen4_Config,
Main_GGUF_FLUX_Config,
Video_ExternalAPI_Runway_Config,
Video_ExternalAPI_Veo3_Config,
)
from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config
from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config
from invokeai.backend.model_manager.configs.t2i_adapter import (
T2IAdapter_Diffusers_SD1_Config,
T2IAdapter_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
from invokeai.backend.model_manager.configs.textual_inversion import (
TI_File_SD1_Config,
TI_File_SD2_Config,
TI_File_SDXL_Config,
TI_Folder_SD1_Config,
TI_Folder_SD2_Config,
TI_Folder_SDXL_Config,
)
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.configs.vae import (
VAE_Checkpoint_FLUX_Config,
VAE_Checkpoint_SD1_Config,
VAE_Checkpoint_SD2_Config,
VAE_Checkpoint_SDXL_Config,
VAE_Diffusers_SD1_Config,
VAE_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelSourceType,
ModelType,
variant_type_adapter,
)
logger = logging.getLogger(__name__)
app_config = get_config()
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
Union[
# Main (Pipeline) - diffusers format
Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()],
Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()],
Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
# Main (Pipeline) - checkpoint format
Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
# Main (Pipeline) - quantized formats
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
# VAE - checkpoint format
Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
# VAE - diffusers format
Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
# ControlNet - checkpoint format
Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
# ControlNet - diffusers format
Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()],
Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()],
# LoRA - LyCORIS format
Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()],
Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
# LoRA - OMI format
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
# LoRA - diffusers format
Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
# ControlLoRA - diffusers format
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
# T5 Encoder - all formats
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
# TI - file format
Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()],
# TI - folder format
Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()],
Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()],
Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()],
# IP Adapter - InvokeAI format
Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()],
# IP Adapter - checkpoint format
Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()],
# T2I Adapter - diffusers format
Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()],
Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()],
# Misc models
Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()],
Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()],
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
# Main - external API
Annotated[Main_ExternalAPI_ChatGPT4o_Config, Main_ExternalAPI_ChatGPT4o_Config.get_tag()],
Annotated[Main_ExternalAPI_Gemini2_5_Config, Main_ExternalAPI_Gemini2_5_Config.get_tag()],
Annotated[Main_ExternalAPI_Imagen3_Config, Main_ExternalAPI_Imagen3_Config.get_tag()],
Annotated[Main_ExternalAPI_Imagen4_Config, Main_ExternalAPI_Imagen4_Config.get_tag()],
Annotated[Main_ExternalAPI_FluxKontext_Config, Main_ExternalAPI_FluxKontext_Config.get_tag()],
# Video - external API
Annotated[Video_ExternalAPI_Veo3_Config, Video_ExternalAPI_Veo3_Config.get_tag()],
Annotated[Video_ExternalAPI_Runway_Config, Video_ExternalAPI_Runway_Config.get_tag()],
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Discriminator(Config_Base.get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
class ModelConfigFactory:
@staticmethod
def from_dict(fields: dict[str, Any]) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(fields)
return model
@staticmethod
def build_common_fields(
mod: ModelOnDisk,
override_fields: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Builds the common fields for all model configs.
Args:
mod: The model on disk to extract fields from.
overrides: A optional dictionary of fields to override. These fields will take precedence over the values
extracted from the model on disk.
- Casts string fields to their Enum types.
- Does not validate the fields against the model config schema.
"""
_overrides: dict[str, Any] = override_fields or {}
fields: dict[str, Any] = {}
if "type" in _overrides:
fields["type"] = ModelType(_overrides["type"])
if "format" in _overrides:
fields["format"] = ModelFormat(_overrides["format"])
if "base" in _overrides:
fields["base"] = BaseModelType(_overrides["base"])
if "source_type" in _overrides:
fields["source_type"] = ModelSourceType(_overrides["source_type"])
if "variant" in _overrides:
fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"])
fields["path"] = mod.path.as_posix()
fields["source"] = _overrides.get("source") or fields["path"]
fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path
fields["name"] = _overrides.get("name") or mod.name
fields["hash"] = _overrides.get("hash") or mod.hash()
fields["key"] = _overrides.get("key") or uuid_string()
fields["description"] = _overrides.get("description")
fields["file_size"] = _overrides.get("file_size") or mod.size()
return fields
@staticmethod
def from_model_on_disk(
mod: str | Path | ModelOnDisk,
override_fields: dict[str, Any] | None = None,
hash_algo: HASHING_ALGORITHMS = "blake3_single",
) -> AnyModelConfig:
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
# We will always need these fields to build any model config.
fields = ModelConfigFactory.build_common_fields(mod, override_fields)
# Store results as a mapping of config class to either an instance of that class or an exception
# that was raised when trying to build it.
results: dict[str, AnyModelConfig | Exception] = {}
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for config_class in Config_Base.CONFIG_CLASSES:
class_name = config_class.__name__
try:
instance = config_class.from_model_on_disk(mod, fields)
# Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of
# the AnyModelConfig union.
results[class_name] = instance # type: ignore
except NotAMatchError as e:
results[class_name] = e
logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
except ValidationError as e:
# This means the model matched, but we couldn't create the pydantic model instance for the config.
# Maybe invalid overrides were provided?
results[class_name] = e
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
except Exception as e:
results[class_name] = e
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
matches = [r for r in results.values() if isinstance(r, Config_Base)]
if not matches and app_config.allow_unknown_models:
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
return Unknown_Config(**fields)
if len(matches) > 1:
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
#
# Known cases:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
# Sort the matching according to known special cases.
def sort_key(m: AnyModelConfig) -> int:
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case _:
return 3
matches.sort(key=sort_key)
logger.warning(
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
)
instance = matches[0]
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
return instance

View File

@@ -0,0 +1,40 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class FLUXRedux_Checkpoint_Config(Config_Base):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
raise NotAMatchError("model does not match FLUX Tools Redux heuristics")
return cls(**override_fields)

View File

@@ -0,0 +1,182 @@
import json
from functools import cache
from pathlib import Path
from pydantic import BaseModel, ValidationError
from pydantic_core import CoreSchema, SchemaValidator
from typing_extensions import Any
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
class NotAMatchError(Exception):
"""Exception for when a model does not match a config class.
Args:
reason: The reason why the model did not match.
"""
def __init__(self, reason: str):
super().__init__(reason)
def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
paths_to_check = config_path if isinstance(config_path, set) else {config_path}
problems: dict[Path, str] = {}
for p in paths_to_check:
if not p.exists():
problems[p] = "file does not exist"
continue
try:
with open(p, "r") as file:
config = json.load(file)
return config
except Exception as e:
problems[p] = str(e)
continue
raise NotAMatchError(f"unable to load config file(s): {problems}")
def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
"""Load the diffusers/transformers model config file and return the class name.
Raises:
NotAMatch if the config file is missing or does not contain a valid class name.
"""
config = get_config_dict_or_raise(config_path)
try:
if "_class_name" in config:
# This is a diffusers-style config
config_class_name = config["_class_name"]
elif "architectures" in config:
# This is a transformers-style config
config_class_name = config["architectures"][0]
else:
raise ValueError("missing _class_name or architectures field")
except Exception as e:
raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e
if not isinstance(config_class_name, str):
raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}")
return config_class_name
def raise_for_class_name(config_path: Path | set[Path], expected: set[str]) -> None:
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
Args:
config_path: The path to the config file.
expected: The expected class names.
Raises:
NotAMatch if the class name is not in the expected set.
"""
class_name = get_class_name_from_config_dict_or_raise(config_path)
if class_name not in expected:
raise NotAMatchError(f"invalid class name from config: {class_name}")
def raise_for_override_fields(candidate_config_class: type[BaseModel], override_fields: dict[str, Any]) -> None:
"""Check if the provided override fields are valid for the config class using pydantic.
For example, if the candidate config class has a field "base" of type Literal[BaseModelType.StableDiffusion1], and
the override fields contain "base": BaseModelType.Flux, this function will raise NotAMatch.
Args:
candidate_config_class: The config class that is being tested.
override_fields: The override fields provided by the user.
Raises:
NotAMatch if any override field is invalid for the config class.
"""
for field_name, override_value in override_fields.items():
if field_name not in candidate_config_class.model_fields:
raise NotAMatchError(f"unknown override field: {field_name}")
try:
PydanticFieldValidator.validate_field(candidate_config_class, field_name, override_value)
except ValidationError as e:
raise NotAMatchError(f"invalid override for field '{field_name}': {e}") from e
def raise_if_not_file(mod: ModelOnDisk) -> None:
"""Raise NotAMatch if the model path is not a file."""
if not mod.path.is_file():
raise NotAMatchError("model path is not a file")
def raise_if_not_dir(mod: ModelOnDisk) -> None:
"""Raise NotAMatch if the model path is not a directory."""
if not mod.path.is_dir():
raise NotAMatchError("model path is not a directory")
def state_dict_has_any_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
"""Returns true if the state dict has any of the specified keys."""
_keys = {keys} if isinstance(keys, str) else keys
return any(key in state_dict for key in _keys)
def state_dict_has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
def state_dict_has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
def common_config_paths(path: Path) -> set[Path]:
"""Returns common config file paths for models stored in directories."""
return {path / "config.json", path / "model_index.json"}
class PydanticFieldValidator:
"""Utility class for validating individual fields of a Pydantic model without instantiating the whole model.
See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
"""
@staticmethod
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
"""Find the Pydantic core schema for a specific field in a model."""
schema: CoreSchema = model.__pydantic_core_schema__.copy()
# we shallow copied, be careful not to mutate the original schema!
assert schema["type"] in ["definitions", "model"]
# find the field schema
field_schema = schema["schema"] # type: ignore
while "fields" not in field_schema:
field_schema = field_schema["schema"] # type: ignore
field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
# if the original schema is a definition schema, replace the model schema with the field schema
if schema["type"] == "definitions":
schema["schema"] = field_schema
return schema
else:
return field_schema
@cache
@staticmethod
def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
"""Get a SchemaValidator for a specific field in a model."""
return SchemaValidator(PydanticFieldValidator.find_field_schema(model, field_name))
@staticmethod
def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
"""Validate a value for a specific field in a model."""
return PydanticFieldValidator.get_validator(model, field_name).validate_python(value)

View File

@@ -0,0 +1,180 @@
from abc import ABC
from typing import (
Literal,
Self,
)
from pydantic import BaseModel, Field
from typing_extensions import Any
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class IPAdapter_Config_Base(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter diffusers format models."""
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_has_weights_file(mod)
cls._validate_has_image_encoder_metadata_file(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None:
weights_file = mod.path / "ip_adapter.bin"
if not weights_file.exists():
raise NotAMatchError("missing ip_adapter.bin weights file")
@classmethod
def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None:
image_encoder_metadata_file = mod.path / "image_encoder.txt"
if not image_encoder_metadata_file.exists():
raise NotAMatchError("missing image_encoder.txt metadata file")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
try:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 1280:
return BaseModelType.StableDiffusionXL
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case _:
raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_ip_adapter(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
},
):
raise NotAMatchError("model does not match Checkpoint IP Adapter heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_ip_adapter(state_dict):
return BaseModelType.Flux
try:
cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 1280:
return BaseModelType.StableDiffusionXL
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case _:
raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)

View File

@@ -0,0 +1,42 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
)
class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"LlavaOnevisionForConditionalGeneration",
},
)
return cls(**override_fields)

View File

@@ -0,0 +1,323 @@
from abc import ABC
from pathlib import Path
from typing import (
Any,
Literal,
Self,
)
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Any
from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import (
Config_Base,
)
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_ending_with,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
FluxLoRAFormat,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
model_config = ConfigDict(extra="forbid")
class LoRA_Config_Base(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
trigger_phrases: set[str] | None = Field(
default=None,
description="Set of trigger phrases for this model",
)
default_settings: LoraModelDefaultSettings | None = Field(
default=None,
description="Default settings for this model",
)
def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
# TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later.
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
state_dict = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(state_dict, mod.metadata())
return value
class LoRA_OMI_Config_Base(LoRA_Config_Base):
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_omi_lora(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model metadata does not look like an OMI LoRA."""
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA")
metadata = mod.metadata()
metadata_looks_like_omi_lora = (
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
)
if not metadata_looks_like_omi_lora:
raise NotAMatchError("metadata does not look like OMI LoRA")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
if architecture == stable_diffusion_xl_1_lora:
return BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
return BaseModelType.Flux
else:
raise NotAMatchError(f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_lora(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
# First rule out ControlLoRA and Diffusers LoRA
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA")
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"lora_te_",
"lora_unet_",
"lora_te1_",
"lora_te2_",
"lora_transformer_",
},
)
has_key_with_lora_suffix = state_dict_has_any_keys_ending_with(
mod.load_state_dict(),
{
"to_k_lora.up.weight",
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
},
)
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
raise NotAMatchError("model does not match LyCORIS LoRA heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
if _get_flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise NotAMatchError(f"unrecognized token vector length {token_vector_length}")
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlAdapter_Config_Base(ABC, BaseModel):
default_settings: ControlAdapterDefaultSettings | None = Field(None)
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base):
"""Model config for Control LoRA models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
trigger_phrases: set[str] | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_control_lora(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
state_dict = mod.load_state_dict()
if not is_state_dict_likely_flux_control(state_dict):
raise NotAMatchError("model state dict does not look like a Flux Control LoRA")
class LoRA_Diffusers_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Diffusers models."""
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
# the weights format. FLUX Diffusers LoRAs are single files.
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
if _get_flux_lora_format(mod):
return BaseModelType.Flux
# If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
path_to_weight_file = cls._get_weight_file_or_raise(mod)
state_dict = mod.load_state_dict(path_to_weight_file)
token_vector_length = lora_token_vector_length(state_dict)
match token_vector_length:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized token vector length {token_vector_length}")
@classmethod
def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
for wf in weight_files:
if wf.exists():
return wf
raise NotAMatchError("missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)

View File

@@ -0,0 +1,692 @@
from abc import ABC
from typing import Any, Literal, Self
from pydantic import BaseModel, ConfigDict, Field
from invokeai.backend.model_manager.configs.base import (
Checkpoint_Config_Base,
Config_Base,
Diffusers_Config_Base,
SubmodelDefinition,
)
from invokeai.backend.model_manager.configs.clip_embed import get_clip_variant_type_from_config
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_exact,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
FluxVariantType,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
cfg_rescale_multiplier: float | None = Field(
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
model_config = ConfigDict(extra="forbid")
class Main_Config_Base(ABC, BaseModel):
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
trigger_phrases: set[str] | None = Field(
default=None,
description="Set of trigger phrases for this model",
)
default_settings: MainModelDefaultSettings | None = Field(
default=None,
description="Default settings for this model",
)
def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool:
bnb_nf4_keys = {
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
"model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
}
return any(key in state_dict for key in bnb_nf4_keys)
def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
for key in state_dict.keys():
if isinstance(key, int):
continue
elif key.startswith(
(
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
)
):
return True
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
return True
return False
class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_base(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, prediction_type=prediction_type, variant=variant)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
raise NotAMatchError("unable to determine base type from state dict")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
base = cls.model_fields["base"].default
if base is BaseModelType.StableDiffusion2:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in state_dict:
if state_dict["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif state_dict["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction
else:
return SchedulerPredictionType.Epsilon
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.0.0.weight"
if key_name not in state_dict:
raise NotAMatchError("unable to determine model variant from state dict")
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'")
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
# Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight".
#
# Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
#
# Input channels for known FLUX models:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
in_channels = None
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
if key in state_dict:
in_channels = state_dict[key].shape[1]
break
if in_channels is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
return None
# Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of
# certain keys to distinguish between them.
is_flux_dev = (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
)
if is_flux_dev and in_channels == 384:
return FluxVariantType.DevFill
elif is_flux_dev:
return FluxVariantType.Dev
else:
# Must be a Schnell model...?
return FluxVariantType.Schnell
class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_is_flux(mod)
cls._validate_does_not_look_like_bnb_quantized(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_exact(
mod.load_state_dict(),
{
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
},
):
raise NotAMatchError("state dict does not look like a FLUX checkpoint")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if has_bnb_nf4_keys:
raise NotAMatchError("state dict looks like bnb quantized nf4")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if has_ggml_tensors:
raise NotAMatchError("state dict looks like GGUF quantized")
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_model_looks_like_bnb_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if not has_bnb_nf4_keys:
raise NotAMatchError("state dict does not look like bnb quantized nf4")
class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml_tensors:
raise NotAMatchError("state dict does not look like GGUF quantized")
class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
# SD 1.x and 2.x
"StableDiffusionPipeline",
"StableDiffusionInpaintPipeline",
# SDXL
"StableDiffusionXLPipeline",
"StableDiffusionXLInpaintPipeline",
# SDXL Refiner
"StableDiffusionXLImg2ImgPipeline",
# TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase.
"LatentConsistencyModelPipeline",
},
)
cls._validate_base(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
variant=variant,
prediction_type=prediction_type,
repo_variant=repo_variant,
)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
unet_conf = get_config_dict_or_raise(mod.path / "unet" / "config.json")
cross_attention_dim = unet_conf.get("cross_attention_dim")
match cross_attention_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXLRefiner
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross_attention_dim {cross_attention_dim}")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
scheduler_conf = get_config_dict_or_raise(mod.path / "scheduler" / "scheduler_config.json")
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
match prediction_type:
case "v_prediction":
return SchedulerPredictionType.VPrediction
case "epsilon":
return SchedulerPredictionType.Epsilon
case _:
raise NotAMatchError(f"unrecognized scheduler prediction_type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default
unet_config = get_config_dict_or_raise(mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'")
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
submodels: dict[SubModelType, SubmodelDefinition] | None = Field(
description="Loadable submodels in this model",
default=None,
)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# This check implies the base type - no further validation needed.
raise_for_class_name(
common_config_paths(mod.path),
{
"StableDiffusion3Pipeline",
"SD3Transformer2DModel",
},
)
submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
submodels=submodels,
repo_variant=repo_variant,
)
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
config = get_config_dict_or_raise(common_config_paths(mod.path))
submodels: dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
# Anything that starts with an underscore is top-level metadata, not a submodel
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
# The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
# The value value is something like ["diffusers", "SD3Transformer2DModel"]
_library_name, class_name = value
match class_name:
case "CLIPTextModelWithProjection":
model_type = ModelType.CLIPEmbed
path_or_prefix = (mod.path / key).resolve().as_posix()
# We need to read the config to determine the variant of the CLIP model.
clip_embed_config = get_config_dict_or_raise(
{
mod.path / key / "config.json",
mod.path / key / "model_index.json",
}
)
variant = get_clip_variant_type_from_config(clip_embed_config)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case "SD3Transformer2DModel":
model_type = ModelType.Main
path_or_prefix = (mod.path / key).resolve().as_posix()
variant = None
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case _:
pass
return submodels
class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# This check implies the base type - no further validation needed.
raise_for_class_name(
common_config_paths(mod.path),
{
"CogView4Pipeline",
},
)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
repo_variant=repo_variant,
)
class ExternalAPI_Config_Base(ABC, BaseModel):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise NotAMatchError("External API models cannot be built from disk")
class Main_ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o)
class Main_ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5)
class Main_ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3)
class Main_ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
class Main_ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
class Video_Config_Base(ABC, BaseModel):
type: Literal[ModelType.Video] = Field(default=ModelType.Video)
trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
default_settings: MainModelDefaultSettings | None = Field(
description="Default settings for this model", default=None
)
class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)

View File

@@ -0,0 +1,44 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"SiglipModel",
},
)
return cls(**override_fields)

View File

@@ -0,0 +1,54 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
class Spandrel_Checkpoint_Config(Config_Base):
"""Model config for Spandrel Image to Image models."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_spandrel_loads_model(mod)
return cls(**override_fields)
@classmethod
def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(mod.path)
except Exception as e:
raise NotAMatchError("model does not match SpandrelImageToImage heuristics") from e

View File

@@ -0,0 +1,79 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"T2IAdapter",
},
)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
adapter_type = config_dict.get("adapter_type")
match adapter_type:
case "full_adapter_xl":
return BaseModelType.StableDiffusionXL
case "full_adapter" | "light_adapter":
return BaseModelType.StableDiffusion1
case _:
raise NotAMatchError(f"unrecognized adapter_type '{adapter_type}'")
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)

View File

@@ -0,0 +1,87 @@
from typing import Any, Literal, Self
from pydantic import Field
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
state_dict_has_any_keys_ending_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
class T5Encoder_T5Encoder_Config(Config_Base):
"""Configuration for T5 Encoder models in a bespoke, diffusers-like format. The model weights are expected to be in
a folder called text_encoder_2 inside the model directory, with a config file named model.safetensors.index.json."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"T5EncoderModel",
},
)
cls.raise_if_doesnt_have_unquantized_config_file(mod)
return cls(**override_fields)
@classmethod
def raise_if_doesnt_have_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
if not has_unquantized_config:
raise NotAMatchError("missing text_encoder_2/model.safetensors.index.json")
class T5Encoder_BnBLLMint8_Config(Config_Base):
"""Configuration for T5 Encoder models quantized by bitsandbytes' LLM.int8."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"T5EncoderModel",
},
)
cls.raise_if_filename_doesnt_look_like_bnb_quantized(mod)
cls.raise_if_state_dict_doesnt_look_like_bnb_quantized(mod)
return cls(**override_fields)
@classmethod
def raise_if_filename_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
if not filename_looks_like_bnb:
raise NotAMatchError("filename does not look like bnb quantized llm_int8")
@classmethod
def raise_if_state_dict_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_scb_key_suffix = state_dict_has_any_keys_ending_with(mod.load_state_dict(), "SCB")
if not has_scb_key_suffix:
raise NotAMatchError("state dict does not look like bnb quantized llm_int8")

View File

@@ -0,0 +1,156 @@
from abc import ABC
from pathlib import Path
from typing import (
Literal,
Self,
)
import torch
from pydantic import BaseModel, Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class TI_Config_Base(ABC, BaseModel):
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
@classmethod
def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod, path)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
try:
p = path or mod.path
if not p.exists():
return False
if p.is_dir():
return False
if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]:
return True
state_dict = mod.load_state_dict(p)
# Heuristic: textual inversion embeddings have these keys
if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
return True
# Heuristic: small state dict with all tensor values
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
return True
return False
except Exception:
return False
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
except Exception as e:
raise NotAMatchError(f"unable to load state dict from {p}: {e}") from e
try:
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
token_dim = state_dict["emb_params"].shape[-1]
elif "clip_g" in state_dict:
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
except Exception as e:
raise NotAMatchError(f"unable to determine token dimension from state dict in {p}: {e}") from e
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized token dimension {token_dim}")
class TI_File_Config_Base(TI_Config_Base):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
if not cls._file_looks_like_embedding(mod):
raise NotAMatchError("model does not look like a textual inversion embedding file")
cls._validate_base(mod)
return cls(**override_fields)
class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class TI_Folder_Config_Base(TI_Config_Base):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
for p in mod.weight_files():
if cls._file_looks_like_embedding(mod, p):
cls._validate_base(mod, p)
return cls(**override_fields)
raise NotAMatchError("model does not look like a textual inversion embedding folder")
class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)

View File

@@ -0,0 +1,24 @@
from typing import Any, Literal, Self
from pydantic import Field
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class Unknown_Config(Config_Base):
"""Model config for unknown models, used as a fallback when we cannot identify a model."""
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatchError("unknown model config cannot match any model")

View File

@@ -0,0 +1,166 @@
import re
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
REGEX_TO_BASE: dict[str, BaseModelType] = {
r"xl": BaseModelType.StableDiffusionXL,
r"sd2": BaseModelType.StableDiffusion2,
r"vae": BaseModelType.StableDiffusion1,
r"FLUX.1-schnell_ae": BaseModelType.Flux,
}
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_vae(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"encoder.conv_in",
"decoder.conv_in",
},
):
raise NotAMatchError("model does not match Checkpoint VAE heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
for regexp, base in REGEX_TO_BASE.items():
if re.search(regexp, mod.path.name, re.IGNORECASE):
return base
raise NotAMatchError("cannot determine base type")
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"AutoencoderKL",
"AutoencoderTiny",
},
)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
# Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
@classmethod
def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
# Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
# we can do is guess based on name.
return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))
@classmethod
def _guess_name(cls, mod: ModelOnDisk) -> str:
name = mod.path.name
if name == "vae":
name = mod.path.parent.name
return name
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
if cls._config_looks_like_sdxl(config_dict):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
return BaseModelType.StableDiffusionXL
else:
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
return BaseModelType.StableDiffusion1
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)