refactor(mm): add config validation utils, make it all consistent and clean

This commit is contained in:
psychedelicious
2025-09-25 21:53:05 +10:00
parent 925698a688
commit 9745c25b1b

View File

@@ -26,6 +26,7 @@ import re
import time
from abc import ABC
from enum import Enum
from functools import cache
from inspect import isabstract
from pathlib import Path
from typing import (
@@ -40,6 +41,7 @@ from typing import (
import torch
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
from pydantic_core import CoreSchema, SchemaValidator
from typing_extensions import Annotated, Any, Dict
from invokeai.app.services.config.config_default import get_config
@@ -102,6 +104,38 @@ class NotAMatch(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
    """Extract the core schema for a single field of a pydantic model.

    Args:
        model: The model class whose `__pydantic_core_schema__` is inspected.
        field_name: The name of the field to look up.

    Returns:
        The field's schema. If the model's top-level schema is a "definitions" schema, a shallow copy of
        that wrapper is returned with its inner schema replaced by the field schema; otherwise the field
        schema itself is returned.

    Raises:
        ValueError: If the model's top-level schema type is neither "definitions" nor "model".
        KeyError: If no fields mapping can be found, or if `field_name` is not one of the fields.
    """
    # Shallow copy: top-level keys may be rebound below, but the nested dicts are shared with the model's own
    # schema and must never be mutated in place.
    schema: CoreSchema = model.__pydantic_core_schema__.copy()

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if schema["type"] not in ("definitions", "model"):
        raise ValueError(f"unsupported core schema type for {model.__name__}: {schema['type']!r}")

    # Walk down through nested wrapper schemas until we reach the one that holds the fields mapping.
    fields_schema = schema["schema"]  # type: ignore
    while "fields" not in fields_schema:
        if "schema" not in fields_schema:
            raise KeyError(f"no fields mapping found in core schema for {model.__name__}")
        fields_schema = fields_schema["schema"]  # type: ignore

    fields = fields_schema["fields"]  # type: ignore
    if field_name not in fields:
        raise KeyError(f"{model.__name__} has no field named {field_name!r}")
    field_schema = fields[field_name]["schema"]  # type: ignore

    # If the original schema is a definitions schema, replace the model schema with the field schema.
    # This rebinds a key on the copy only.
    if schema["type"] == "definitions":
        schema["schema"] = field_schema
        return schema
    return field_schema
@cache
def validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
    """Build a `SchemaValidator` for a single field of `model`.

    Results are memoized by `functools.cache`, keyed on the `(model, field_name)` pair, so the validator for
    a given field is only constructed once.
    """
    field_schema = find_field_schema(model, field_name)
    return SchemaValidator(field_schema)
def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
    """Validate `value` against the schema of one field of `model` and return the validated result.

    Whatever `SchemaValidator.validate_python` raises for a rejected value propagates to the caller.
    """
    field_validator = validator(model, field_name)
    return field_validator.validate_python(value)
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
# NotAMatch exceptions as easy and consistent as possible.
@@ -123,12 +157,15 @@ def _get_config_or_raise(
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
def _validate_class_names(
def _get_class_name_from_config(
config_class: type,
config_path: Path,
valid_class_names: set[str],
) -> None:
"""Raise NotAMatch if the config file is missing or does not contain a valid class name."""
) -> str:
"""Load the config file and return the class name.
Raises:
NotAMatch if the config file is missing or does not contain a valid class name.
"""
config = _get_config_or_raise(config_class, config_path)
@@ -142,36 +179,48 @@ def _validate_class_names(
except Exception as e:
raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e
if config_class_name not in valid_class_names:
raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}")
if not isinstance(config_class_name, str):
raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}")
return config_class_name
def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None:
    """Check that the class name recorded in a config file is one of the expected class names.

    Args:
        config_class: The config class that is being tested.
        config_path: The path to the config file.
        expected: The expected class names.

    Raises:
        NotAMatch if the class name read from the config file is not in `expected`. Anything raised by
        `_get_class_name_from_config` while reading the class name propagates unchanged.
    """
    class_name = _get_class_name_from_config(config_class, config_path)
    if class_name not in expected:
        raise NotAMatch(config_class, f"invalid class name from config: {class_name}")
def _validate_override_fields(
    config_class: type[BaseModel],
    override_fields: dict[str, Any],
) -> None:
    """Check if the provided override fields are valid for the config class.

    Each override must name a field declared on `config_class` and its value must pass validation against
    that field's schema.

    Args:
        config_class: The config class that is being tested.
        override_fields: The override fields provided by the user.

    Raises:
        NotAMatch if any override field is unknown to the config class or its value fails validation.
    """
    for field_name, override_value in override_fields.items():
        if field_name not in config_class.model_fields:
            raise NotAMatch(config_class, f"unknown override field: {field_name}")
        try:
            validate_model_field(config_class, field_name, override_value)
        except ValidationError as e:
            raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
def _raise_if_not_file(
def _validate_is_file(
config_class: type,
mod: ModelOnDisk,
) -> None:
@@ -180,7 +229,7 @@ def _raise_if_not_file(
raise NotAMatch(config_class, "model path is not a file")
def _raise_if_not_dir(
def _validate_is_dir(
config_class: type,
mod: ModelOnDisk,
) -> None:
@@ -346,7 +395,10 @@ class UnknownModelConfig(ModelConfigBase):
class CheckpointConfigBase(ABC, BaseModel):
"""Base class for checkpoint-style models."""
config_path: str | None = Field(None, description="Path to the config for this model, if any.")
config_path: str | None = Field(
description="Path to the config for this model, if any.",
default=None,
)
converted_at: float | None = Field(
description="When this model was last converted to diffusers",
default_factory=time.time,
@@ -365,66 +417,57 @@ class T5EncoderConfig(ModelConfigBase):
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
"format": ModelFormat.T5Encoder,
}
VALID_CLASS_NAMES: ClassVar = {
"T5EncoderModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    # The class name is read from the `text_encoder_2` subfolder, the same subfolder that the
    # unquantized-weights check below inspects, not from the model root.
    _validate_class_name(cls, mod.path / "text_encoder_2" / "config.json", {"T5EncoderModel"})

    # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
    cls._validate_has_unquantized_config_file(mod)

    return cls(**fields)
@classmethod
def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless `text_encoder_2/model.safetensors.index.json` exists under the model path."""
    has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
    if not has_unquantized_config:
        raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
"format": ModelFormat.BnbQuantizedLlmInt8b,
}
VALID_CLASS_NAMES: ClassVar = {
"T5EncoderModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    # Heuristic: Look for the T5EncoderModel class name in the config
    _validate_class_name(cls, mod.path / "text_encoder_2" / "config.json", {"T5EncoderModel"})

    # Either quantization indicator is sufficient: the model is rejected only when the weight filenames do
    # not mention llm_int8 AND the state dict has no "SCB"-suffixed keys.
    try:
        # Heuristic: look for the quantization in the filename
        cls._validate_filename_looks_like_bnb_quantized(mod)
    except NotAMatch:
        # Heuristic: Look for the presence of "SCB" suffixes in state dict keys
        cls._validate_model_looks_like_bnb_quantized(mod)

    return cls(**fields)
@classmethod
def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless at least one weight file path contains "llm_int8"."""
    for weight_file in mod.weight_files():
        if "llm_int8" in weight_file.as_posix():
            return
    raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8")
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless the model has keys ending with "SCB"."""
    if mod.has_keys_ending_with("SCB"):
        return
    raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
class LoRAConfigBase(ABC, BaseModel):
"""Base class for LoRA models."""
@@ -453,36 +496,40 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field()
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.OMI,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    # OMI LoRAs are always files
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    # Heuristic: differential diagnosis vs ControlLoRA and Diffusers
    cls._validate_is_not_controllora_or_diffusers(mod)

    # Heuristic: Look for OMI LoRA metadata
    cls._validate_metadata_looks_like_omi(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@classmethod
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
    excluded_formats = [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]
    if get_flux_lora_format(mod) not in excluded_formats:
        return
    raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
@classmethod
def _validate_metadata_looks_like_omi(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.

    The metadata must have a truthy "modelspec.sai_model_spec", an "ot_branch" equal to "omi_format", and a
    "modelspec.architecture" whose second "/"-separated segment is "lora" (case-insensitive).
    """
    metadata = mod.metadata()

    if not metadata.get("modelspec.sai_model_spec") or metadata.get("ot_branch") != "omi_format":
        raise NotAMatch(cls, "metadata does not look like OMI LoRA")

    # An architecture with no "/" (including a missing one) has no second segment. Treat that as a
    # non-match instead of letting the index lookup raise IndexError.
    architecture_parts = metadata.get("modelspec.architecture", "").split("/")
    if len(architecture_parts) < 2 or architecture_parts[1].lower() != "lora":
        raise NotAMatch(cls, "metadata does not look like OMI LoRA")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
@@ -512,21 +559,20 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.LyCORIS,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    # Heuristic: differential diagnosis vs ControlLoRA and Diffusers
    cls._validate_is_not_controllora_or_diffusers(mod)

    cls._validate_looks_like_lora(mod)

    return cls(**fields)
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = mod.has_keys_starting_with(
@@ -551,7 +597,12 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics")
return cls(**fields)
@classmethod
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
    excluded_formats = [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]
    if get_flux_lora_format(mod) not in excluded_formats:
        return
    raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
@@ -591,15 +642,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_looks_like_control_lora(mod)

    return cls(**fields)
@classmethod
def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` if the model's state dict does not look like a Flux Control LoRA."""
    state_dict = mod.load_state_dict()
    if not is_state_dict_likely_flux_control(state_dict):
        raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
@@ -618,28 +675,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
base: LoRADiffusers_SupportedBases = Field()
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.Diffusers,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    # Diffusers-style models are always directories
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    # Either heuristic is sufficient: the model is rejected only when it is not detected as a FLUX Diffusers
    # LoRA AND has no pytorch_lora_weights file. Requiring both would reject every LoRA that has the weight
    # file but is not detected as the FLUX Diffusers format.
    try:
        cls._validate_looks_like_diffusers_lora(mod)
    except NotAMatch:
        cls._validate_has_lora_weight_file(mod)

    return cls(**fields)
@classmethod
def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless `get_flux_lora_format` reports the Diffusers format for the model."""
    if get_flux_lora_format(mod) is FluxLoRAFormat.Diffusers:
        return
    raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
@classmethod
def _validate_has_lora_weight_file(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless `pytorch_lora_weights.bin` or `pytorch_lora_weights.safetensors` exists."""
    suffixes = ["bin", "safetensors"]
    weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
    has_lora_weight_file = any(wf.exists() for wf in weight_files)
    if not has_lora_weight_file:
        raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
VAECheckpointConfig_SupportedBases: TypeAlias = Literal[
@@ -657,11 +717,6 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
"format": ModelFormat.Checkpoint,
}
REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = {
r"xl": BaseModelType.StableDiffusionXL,
r"sd2": BaseModelType.StableDiffusion2,
@@ -671,16 +726,20 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_looks_like_vae(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless the model has keys starting with the VAE encoder/decoder prefixes."""
    vae_key_prefixes = {"encoder.conv_in", "decoder.conv_in"}
    if mod.has_keys_starting_with(vae_key_prefixes):
        return
    raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
@@ -704,22 +763,13 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"AutoencoderKL",
"AutoencoderTiny",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"})

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@@ -770,23 +820,13 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.ControlNet,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"ControlNetModel",
"FluxControlNetModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
_validate_is_dir(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_override_fields(cls, fields)
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
_validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"})
base = fields.get("base") or cls._get_base_or_raise(mod)
@@ -829,17 +869,20 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.ControlNet,
"format": ModelFormat.Checkpoint,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_looks_like_controlnet(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@classmethod
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_starting_with(
{
"controlnet",
@@ -855,10 +898,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
):
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetCheckpoint_SupportedBases:
state_dict = mod.load_state_dict()
@@ -970,20 +1009,11 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
"format": ModelFormat.EmbeddingFile,
}
@classmethod
def get_tag(cls) -> Tag:
    """Return the tag for this config: the model type value and the format value joined by a dot."""
    model_type = ModelType.TextualInversion.value
    model_format = ModelFormat.EmbeddingFile.value
    return Tag(f"{model_type}.{model_format}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
_validate_is_file(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_override_fields(cls, fields)
if not cls._file_looks_like_embedding(mod):
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
@@ -997,20 +1027,11 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
"format": ModelFormat.EmbeddingFolder,
}
@classmethod
def get_tag(cls) -> Tag:
    """Return the tag for this config: the model type value and the format value joined by a dot."""
    model_type = ModelType.TextualInversion.value
    model_format = ModelFormat.EmbeddingFolder.value
    return Tag(f"{model_type}.{model_format}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
_validate_is_dir(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_override_fields(cls, fields)
for p in mod.weight_files():
if cls._file_looks_like_embedding(mod, p):
@@ -1105,28 +1126,31 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str = Field()
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.InvokeAI,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_has_weights_file(mod)

    cls._validate_has_image_encoder_metadata_file(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@classmethod
def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless `ip_adapter.bin` exists directly under the model path."""
    if (mod.path / "ip_adapter.bin").exists():
        return
    raise NotAMatch(cls, "missing ip_adapter.bin weights file")
@classmethod
def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless `image_encoder.txt` exists directly under the model path."""
    image_encoder_metadata_file = mod.path / "image_encoder.txt"
    if not image_encoder_metadata_file.exists():
        raise NotAMatch(cls, "missing image_encoder.txt metadata file")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases:
state_dict = mod.load_state_dict()
@@ -1161,17 +1185,19 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
base: IPAdapterCheckpoint_SupportedBases = Field()
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.Checkpoint,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_looks_like_ip_adapter(mod)

    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into one mapping: `cls(**fields, base=base)` raises TypeError (duplicate keyword argument)
    # whenever the caller already supplied `base` in `fields`.
    return cls(**{**fields, "base": base})
@classmethod
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_starting_with(
{
"image_proj.",
@@ -1182,9 +1208,6 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
):
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases:
state_dict = mod.load_state_dict()
@@ -1215,14 +1238,8 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_CLASS_NAMES: ClassVar = {
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
}
@classmethod
def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
try:
hidden_size = config.get("hidden_size")
match hidden_size:
@@ -1241,69 +1258,64 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
"format": ModelFormat.Diffusers,
"variant": ClipVariantType.G,
}
@classmethod
def get_tag(cls) -> Tag:
    """Return the tag for this config: model type, format and CLIP variant values joined by dots."""
    model_type = ModelType.CLIPEmbed.value
    model_format = ModelFormat.Diffusers.value
    clip_variant = ClipVariantType.G.value
    return Tag(f"{model_type}.{model_format}.{clip_variant}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(
        cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
    )

    cls._validate_clip_g_variant(mod)

    return cls(**fields)
@classmethod
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless the variant derived from the model's `config.json` is CLIP-G."""
    config = _get_config_or_raise(cls, mod.path / "config.json")
    clip_variant = cls._get_clip_variant_type(config)
    if clip_variant is not ClipVariantType.G:
        raise NotAMatch(cls, "model does not match CLIP-G heuristics")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
"format": ModelFormat.Diffusers,
"variant": ClipVariantType.L,
}
@classmethod
def get_tag(cls) -> Tag:
    """Return the tag for this config: model type, format and CLIP variant values joined by dots."""
    model_type = ModelType.CLIPEmbed.value
    model_format = ModelFormat.Diffusers.value
    clip_variant = ClipVariantType.L.value
    return Tag(f"{model_type}.{model_format}.{clip_variant}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(
        cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
    )

    cls._validate_clip_l_variant(mod)

    return cls(**fields)
@classmethod
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
    """Raise `NotAMatch` unless the variant derived from the model's `config.json` is CLIP-L."""
    config = _get_config_or_raise(cls, mod.path / "config.json")
    clip_variant = cls._get_clip_variant_type(config)
    if clip_variant is not ClipVariantType.L:
        # This check is for the L variant; the message previously said "CLIP-G" (copy-paste slip).
        raise NotAMatch(cls, "model does not match CLIP-L heuristics")
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for CLIPVision."""
@@ -1312,24 +1324,13 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPVision,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"CLIPVisionModelWithProjection",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"})

    return cls(**fields)
@@ -1347,24 +1348,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T2IAdapter,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"T2IAdapter",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
_validate_is_dir(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_override_fields(cls, fields)
config_path = mod.path / "config.json"
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
_validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"})
base = fields.get("base") or cls._get_base_or_raise(mod)
@@ -1392,17 +1382,18 @@ class SpandrelImageToImageConfig(ModelConfigBase):
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SpandrelImageToImage,
"format": ModelFormat.Checkpoint,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_file(cls, mod)

    _validate_override_fields(cls, fields)

    cls._validate_spandrel_loads_model(mod)

    return cls(**fields)
@classmethod
def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
@@ -1413,7 +1404,6 @@ class SpandrelImageToImageConfig(ModelConfigBase):
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(mod.path)
return cls(**fields)
except Exception as e:
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
@@ -1425,24 +1415,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SigLIP,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"SiglipModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(cls, mod.path / "config.json", {"SiglipModel"})

    return cls(**fields)
@@ -1454,16 +1433,11 @@ class FluxReduxConfig(ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.FluxRedux,
"format": ModelFormat.Checkpoint,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
_validate_is_file(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_override_fields(cls, fields)
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
@@ -1478,24 +1452,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LlavaOnevision,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"LlavaOnevisionForConditionalGeneration",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build this config from a model on disk, raising `NotAMatch` (from the checks below) on mismatch."""
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"})

    return cls(**fields)