From f9686b38fa0769cd714193f8277ecbd1c9f9f31a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 25 Sep 2025 21:53:05 +1000 Subject: [PATCH] refactor(mm): add config validation utils, make it all consistent and clean --- invokeai/backend/model_manager/config.py | 569 +++++++++++------------ 1 file changed, 266 insertions(+), 303 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 118bec28cb..805202ef9b 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -26,6 +26,7 @@ import re import time from abc import ABC from enum import Enum +from functools import cache from inspect import isabstract from pathlib import Path from typing import ( @@ -40,6 +41,7 @@ from typing import ( import torch from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError +from pydantic_core import CoreSchema, SchemaValidator from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config @@ -102,6 +104,38 @@ class NotAMatch(Exception): DEFAULTS_PRECISION = Literal["fp16", "fp32"] + +# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 +def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: + schema: CoreSchema = model.__pydantic_core_schema__.copy() + # we shallow copied, be careful not to mutate the original schema! 
+ + assert schema["type"] in ["definitions", "model"] + + # find the field schema + field_schema = schema["schema"] # type: ignore + while "fields" not in field_schema: + field_schema = field_schema["schema"] # type: ignore + + field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + + # if the original schema is a definition schema, replace the model schema with the field schema + if schema["type"] == "definitions": + schema["schema"] = field_schema + return schema + else: + return field_schema + + +@cache +def validator(model: type[BaseModel], field_name: str) -> SchemaValidator: + return SchemaValidator(find_field_schema(model, field_name)) + + +def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any: + return validator(model, field_name).validate_python(value) + + # These utility functions are tightly coupled to the config classes below in order to make the process of raising # NotAMatch exceptions as easy and consistent as possible. @@ -123,12 +157,15 @@ def _get_config_or_raise( raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e -def _validate_class_names( +def _get_class_name_from_config( config_class: type, config_path: Path, - valid_class_names: set[str], -) -> None: - """Raise NotAMatch if the config file is missing or does not contain a valid class name.""" +) -> str: + """Load the config file and return the class name. + + Raises: + NotAMatch if the config file is missing or does not contain a valid class name. 
+ """ config = _get_config_or_raise(config_class, config_path) @@ -142,36 +179,48 @@ def _validate_class_names( except Exception as e: raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e - if config_class_name not in valid_class_names: - raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}") + if not isinstance(config_class_name, str): + raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}") + + return config_class_name -def _validate_overrides( - config_class: type, - provided_overrides: dict[str, Any], - valid_overrides: dict[str, Any], -) -> None: - """Check if the provided overrides match the valid overrides for this config class. +def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None: + """Check if the class name in the config file matches the expected class names. Args: config_class: The config class that is being tested. - provided_overrides: The overrides provided by the user. - valid_overrides: The overrides that are valid for this config class. + config_path: The path to the config file. + expected: The expected class names.""" + + class_name = _get_class_name_from_config(config_class, config_path) + if class_name not in expected: + raise NotAMatch(config_class, f"invalid class name from config: {class_name}") + + +def _validate_override_fields( + config_class: type[BaseModel], + override_fields: dict[str, Any], +) -> None: + """Check if the provided override fields are valid for the config class. + + Args: + config_class: The config class that is being tested. + override_fields: The override fields provided by the user. Raises: - NotAMatch if any override does not match the allowed value. + NotAMatch if any override field is invalid for the config. 
""" - for key, value in valid_overrides.items(): - if key not in provided_overrides: - continue - if provided_overrides[key] != value: - raise NotAMatch( - config_class, - f"override {key}={provided_overrides[key]} does not match required value {key}={value}", - ) + for field_name, override_value in override_fields.items(): + if field_name not in config_class.model_fields: + raise NotAMatch(config_class, f"unknown override field: {field_name}") + try: + validate_model_field(config_class, field_name, override_value) + except ValidationError as e: + raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e -def _raise_if_not_file( +def _validate_is_file( config_class: type, mod: ModelOnDisk, ) -> None: @@ -180,7 +229,7 @@ def _raise_if_not_file( raise NotAMatch(config_class, "model path is not a file") -def _raise_if_not_dir( +def _validate_is_dir( config_class: type, mod: ModelOnDisk, ) -> None: @@ -346,7 +395,10 @@ class UnknownModelConfig(ModelConfigBase): class CheckpointConfigBase(ABC, BaseModel): """Base class for checkpoint-style models.""" - config_path: str | None = Field(None, description="Path to the config for this model, if any.") + config_path: str | None = Field( + description="Path to the config for this model, if any.", + default=None, + ) converted_at: float | None = Field( description="When this model was last converted to diffusers", default_factory=time.time, @@ -365,66 +417,57 @@ class T5EncoderConfig(ModelConfigBase): type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.T5Encoder, - "format": ModelFormat.T5Encoder, - } - - VALID_CLASS_NAMES: ClassVar = { - "T5EncoderModel", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, 
cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "text_encoder_2" / "config.json", {"T5EncoderModel"}) - # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models) + cls._validate_has_unquantized_config_file(mod) + + return cls(**fields) + + @classmethod + def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None: has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() if not has_unquantized_config: raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") - return cls(**fields) - class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.T5Encoder, - "format": ModelFormat.BnbQuantizedLlmInt8b, - } - - VALID_CLASS_NAMES: ClassVar = { - "T5EncoderModel", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - # Heuristic: Look for the T5EncoderModel class name in the config - _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "text_encoder_2" / "config.json", {"T5EncoderModel"}) - # Heuristic: look for the quantization in the filename name - filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) + cls._validate_filename_looks_like_bnb_quantized(mod) - # Heuristic: Look for the presence of "SCB" suffixes in state dict keys - 
has_scb_key_suffix = mod.has_keys_ending_with("SCB") - - if not filename_looks_like_bnb and not has_scb_key_suffix: - raise NotAMatch(cls, "missing bnb quantization indicators") + cls._validate_model_looks_like_bnb_quantized(mod) return cls(**fields) + @classmethod + def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) + if not filename_looks_like_bnb: + raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8") + + @classmethod + def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_scb_key_suffix = mod.has_keys_ending_with("SCB") + if not has_scb_key_suffix: + raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") + class LoRAConfigBase(ABC, BaseModel): """Base class for LoRA models.""" @@ -453,36 +496,40 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field() format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.LoRA, - "format": ModelFormat.OMI, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - # OMI LoRAs are always files - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - # Heuristic: differential diagnosis vs ControlLoRA and Diffusers - if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA") + cls._validate_is_not_controllora_or_diffusers(mod) - # Heuristic: Look for OMI LoRA metadata + cls._validate_metadata_looks_like_omi(mod) + + base = fields.get("base") or cls._get_base_or_raise(mod) + + return cls(**fields, base=base) + + @classmethod + def 
_validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA.""" + flux_format = get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") + + @classmethod + def _validate_metadata_looks_like_omi(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" metadata = mod.metadata() - is_omi_lora_heuristic = ( + + metadata_looks_like_omi_lora = ( bool(metadata.get("modelspec.sai_model_spec")) and metadata.get("ot_branch") == "omi_format" and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" ) - if not is_omi_lora_heuristic: - raise NotAMatch(cls, "model does not match OMI LoRA heuristics") - - base = fields.get("base") or cls._get_base_or_raise(mod) - - return cls(**fields, base=base) + if not metadata_looks_like_omi_lora: + raise NotAMatch(cls, "metadata does not look like OMI LoRA") @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]: @@ -512,21 +559,20 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.LoRA, - "format": ModelFormat.LyCORIS, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - # Heuristic: differential diagnosis vs ControlLoRA and Diffusers - if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA") + 
cls._validate_is_not_controllora_or_diffusers(mod) + cls._validate_looks_like_lora(mod) + + return cls(**fields) + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. # Some main models have these keys, likely due to the creator merging in a LoRA. has_key_with_lora_prefix = mod.has_keys_starting_with( @@ -551,7 +597,12 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): if not has_key_with_lora_prefix and not has_key_with_lora_suffix: raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") - return cls(**fields) + @classmethod + def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA.""" + flux_format = get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases: @@ -591,15 +642,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) + _validate_override_fields(cls, fields) + + cls._validate_looks_like_control_lora(mod) + + return cls(**fields) + + @classmethod + def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None: state_dict = mod.load_state_dict() if not is_state_dict_likely_flux_control(state_dict): raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") - return cls(**fields) - ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux] @@ -618,28 +675,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): base: LoRADiffusers_SupportedBases = Field() format: Literal[ModelFormat.Diffusers] = 
Field(default=ModelFormat.Diffusers) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.LoRA, - "format": ModelFormat.Diffusers, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - # Diffusers-style models always directories - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - is_flux_lora_diffusers = get_flux_lora_format(mod) is FluxLoRAFormat.Diffusers + cls._validate_looks_like_diffusers_lora(mod) + cls._validate_has_lora_weight_file(mod) + + return cls(**fields) + + @classmethod + def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None: + flux_lora_format = get_flux_lora_format(mod) + if flux_lora_format is not FluxLoRAFormat.Diffusers: + raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA") + + @classmethod + def _validate_has_lora_weight_file(cls, mod: ModelOnDisk) -> None: suffixes = ["bin", "safetensors"] weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] has_lora_weight_file = any(wf.exists() for wf in weight_files) - - if not is_flux_lora_diffusers and not has_lora_weight_file: - raise NotAMatch(cls, "model does not match Diffusers LoRA heuristics") - - return cls(**fields) + if not has_lora_weight_file: + raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") VAECheckpointConfig_SupportedBases: TypeAlias = Literal[ @@ -657,11 +717,6 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.VAE, - "format": ModelFormat.Checkpoint, - } - REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = { r"xl": BaseModelType.StableDiffusionXL, r"sd2": BaseModelType.StableDiffusion2, @@ -671,16 +726,20 @@ 
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): - raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") + cls._validate_looks_like_vae(mod) base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) + @classmethod + def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: + if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): + raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") + @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases: # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name @@ -704,22 +763,13 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.VAE, - "format": ModelFormat.Diffusers, - } - VALID_CLASS_NAMES: ClassVar = { - "AutoencoderKL", - "AutoencoderTiny", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - _validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"}) base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @@ -770,23 +820,13 @@ class 
ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.ControlNet, - "format": ModelFormat.Diffusers, - } - - VALID_CLASS_NAMES: ClassVar = { - "ControlNetModel", - "FluxControlNetModel", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - _validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"}) base = fields.get("base") or cls._get_base_or_raise(mod) @@ -829,17 +869,20 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.ControlNet, - "format": ModelFormat.Checkpoint, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) + cls._validate_looks_like_controlnet(mod) + + base = fields.get("base") or cls._get_base_or_raise(mod) + + return cls(**fields, base=base) + + @classmethod + def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: if not mod.has_keys_starting_with( { "controlnet", @@ -855,10 +898,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ): raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint") - base = 
fields.get("base") or cls._get_base_or_raise(mod) - - return cls(**fields, base=base) - @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetCheckpoint_SupportedBases: state_dict = mod.load_state_dict() @@ -970,20 +1009,11 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.TextualInversion, - "format": ModelFormat.EmbeddingFile, - } - - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) if not cls._file_looks_like_embedding(mod): raise NotAMatch(cls, "model does not look like a textual inversion embedding file") @@ -997,20 +1027,11 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.TextualInversion, - "format": ModelFormat.EmbeddingFolder, - } - - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) for p in mod.weight_files(): if cls._file_looks_like_embedding(mod, p): @@ -1105,28 +1126,31 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): # time. Need to go through the history to make sure I'm understanding this fully. 
image_encoder_model_id: str = Field() - VALID_OVERRIDES: ClassVar = { - "type": ModelType.IPAdapter, - "format": ModelFormat.InvokeAI, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) + cls._validate_has_weights_file(mod) + + cls._validate_has_image_encoder_metadata_file(mod) + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) + + @classmethod + def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: weights_file = mod.path / "ip_adapter.bin" if not weights_file.exists(): raise NotAMatch(cls, "missing ip_adapter.bin weights file") + @classmethod + def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None: image_encoder_metadata_file = mod.path / "image_encoder.txt" if not image_encoder_metadata_file.exists(): raise NotAMatch(cls, "missing image_encoder.txt metadata file") - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) - @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases: state_dict = mod.load_state_dict() @@ -1161,17 +1185,19 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): base: IPAdapterCheckpoint_SupportedBases = Field() format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.IPAdapter, - "format": ModelFormat.Checkpoint, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) + cls._validate_looks_like_ip_adapter(mod) + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, 
base=base) + + @classmethod + def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: if not mod.has_keys_starting_with( { "image_proj.", @@ -1182,9 +1208,6 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): ): raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) - @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases: state_dict = mod.load_state_dict() @@ -1215,14 +1238,8 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase): type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - VALID_CLASS_NAMES: ClassVar = { - "CLIPModel", - "CLIPTextModel", - "CLIPTextModelWithProjection", - } - @classmethod - def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None: + def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None: try: hidden_size = config.get("hidden_size") match hidden_size: @@ -1241,69 +1258,64 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.CLIPEmbed, - "format": ModelFormat.Diffusers, - "variant": ClipVariantType.G, - } - @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" + _validate_class_name( + cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"} + ) - 
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + cls._validate_clip_g_variant(mod) - config = _get_config_or_raise(cls, config_path) + return cls(**fields) - clip_variant = cls.get_clip_variant_type(config) + @classmethod + def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None: + config = _get_config_or_raise(cls, mod.path / "config.json") + clip_variant = cls._get_clip_variant_type(config) if clip_variant is not ClipVariantType.G: raise NotAMatch(cls, "model does not match CLIP-G heuristics") - return cls(**fields) - class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): """Model config for CLIP-L Embeddings.""" variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.CLIPEmbed, - "format": ModelFormat.Diffusers, - "variant": ClipVariantType.L, - } - @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" + _validate_class_name( + cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"} + ) - _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) - - config = _get_config_or_raise(cls, config_path) - clip_variant = cls.get_clip_variant_type(config) - - if clip_variant is not ClipVariantType.L: - raise NotAMatch(cls, "model does not match CLIP-L heuristics") + cls._validate_clip_l_variant(mod) return cls(**fields) + @classmethod + def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None: + config = _get_config_or_raise(cls, mod.path / "config.json") + clip_variant = cls._get_clip_variant_type(config) + + if clip_variant is not ClipVariantType.L: + 
raise NotAMatch(cls, "model does not match CLIP-L heuristics") + class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): """Model config for CLIPVision.""" @@ -1312,24 +1324,13 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.CLIPVision, - "format": ModelFormat.Diffusers, - } - - VALID_CLASS_NAMES: ClassVar = { - "CLIPVisionModelWithProjection", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" - - _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"}) return cls(**fields) @@ -1347,24 +1348,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.T2IAdapter, - "format": ModelFormat.Diffusers, - } - - VALID_CLASS_NAMES: ClassVar = { - "T2IAdapter", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" - - _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"}) base = fields.get("base") or cls._get_base_or_raise(mod) @@ -1392,17 +1382,18 @@ class 
SpandrelImageToImageConfig(ModelConfigBase): type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage) format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.SpandrelImageToImage, - "format": ModelFormat.Checkpoint, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) + cls._validate_spandrel_loads_model(mod) + + return cls(**fields) + + @classmethod + def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None: try: # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were # explored to avoid this: @@ -1413,7 +1404,6 @@ class SpandrelImageToImageConfig(ModelConfigBase): # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to # maintain it, and the risk of false positive detections is higher. 
SpandrelImageToImageModel.load_from_file(mod.path) - return cls(**fields) except Exception as e: raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e @@ -1425,24 +1415,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.SigLIP, - "format": ModelFormat.Diffusers, - } - - VALID_CLASS_NAMES: ClassVar = { - "SiglipModel", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" - - _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"SiglipModel"}) return cls(**fields) @@ -1454,16 +1433,11 @@ class FluxReduxConfig(ModelConfigBase): format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - VALID_OVERRIDES: ClassVar = { - "type": ModelType.FluxRedux, - "format": ModelFormat.Checkpoint, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_file(cls, mod) + _validate_is_file(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) if not is_state_dict_likely_flux_redux(mod.load_state_dict()): raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics") @@ -1478,24 +1452,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal) - VALID_OVERRIDES: ClassVar = { - 
"type": ModelType.LlavaOnevision, - "format": ModelFormat.Diffusers, - } - - VALID_CLASS_NAMES: ClassVar = { - "LlavaOnevisionForConditionalGeneration", - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _raise_if_not_dir(cls, mod) + _validate_is_dir(cls, mod) - _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + _validate_override_fields(cls, fields) - config_path = mod.path / "config.json" - - _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + _validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"}) return cls(**fields)