From dd7a51b3516014295fbe9b085f67b3e69e607cc8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:21:19 +1000 Subject: [PATCH] tidy(mm): clean up model heuristic utils --- invokeai/backend/model_manager/config.py | 177 ++++++++++++------ .../backend/model_manager/model_on_disk.py | 23 --- 2 files changed, 124 insertions(+), 76 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 57f52b1045..2c9ad226a2 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -106,35 +106,67 @@ class NotAMatch(Exception): DEFAULTS_PRECISION = Literal["fp16", "fp32"] -# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 -def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: - schema: CoreSchema = model.__pydantic_core_schema__.copy() - # we shallow copied, be careful not to mutate the original schema! +class FieldValidator: + """Utility class for validating individual fields of a Pydantic model without instantiating the whole model. - assert schema["type"] in ["definitions", "model"] + See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 + """ - # find the field schema - field_schema = schema["schema"] # type: ignore - while "fields" not in field_schema: - field_schema = field_schema["schema"] # type: ignore + @staticmethod + def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: + """Find the Pydantic core schema for a specific field in a model.""" + schema: CoreSchema = model.__pydantic_core_schema__.copy() + # we shallow copied, be careful not to mutate the original schema! 
- field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + assert schema["type"] in ["definitions", "model"] - # if the original schema is a definition schema, replace the model schema with the field schema - if schema["type"] == "definitions": - schema["schema"] = field_schema - return schema - else: - return field_schema + # find the field schema + field_schema = schema["schema"] # type: ignore + while "fields" not in field_schema: + field_schema = field_schema["schema"] # type: ignore + + field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + + # if the original schema is a definition schema, replace the model schema with the field schema + if schema["type"] == "definitions": + schema["schema"] = field_schema + return schema + else: + return field_schema + + @staticmethod + @cache + def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator: + """Get a SchemaValidator for a specific field in a model.""" + return SchemaValidator(FieldValidator.find_field_schema(model, field_name)) + + @staticmethod + def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any: + """Validate a value for a specific field in a model.""" + return FieldValidator.get_validator(model, field_name).validate_python(value) -@cache -def validator(model: type[BaseModel], field_name: str) -> SchemaValidator: - return SchemaValidator(find_field_schema(model, field_name)) +def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: + """Returns true if the state dict has all of the specified keys.""" + _keys = {keys} if isinstance(keys, str) else keys + return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)}) -def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any: - return validator(model, field_name).validate_python(value) +def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: + """Returns true 
if the state dict has any keys starting with any of the specified prefixes.""" + _prefixes = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes) + return any(key.startswith(_prefixes) for key in state_dict.keys() if isinstance(key, str)) + + +def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: + """Returns true if the state dict has any keys ending with any of the specified suffixes.""" + _suffixes = (suffixes,) if isinstance(suffixes, str) else tuple(suffixes) + return any(key.endswith(_suffixes) for key in state_dict.keys() if isinstance(key, str)) + + +def common_config_paths(path: Path) -> set[Path]: + """Returns common config file paths for models stored in directories.""" + return {path / "config.json", path / "model_index.json"} # These utility functions are tightly coupled to the config classes below in order to make the process of raising @@ -225,7 +257,7 @@ def _validate_override_fields( if field_name not in config_class.model_fields: raise NotAMatch(config_class, f"unknown override field: {field_name}") try: - validate_model_field(config_class, field_name, override_value) + FieldValidator.validate_field(config_class, field_name, override_value) except ValidationError as e: raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e @@ -440,7 +472,13 @@ class T5EncoderConfig(ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"}) + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "T5EncoderModel", + }, + ) cls._validate_has_unquantized_config_file(mod) @@ -465,7 +503,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"}) + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "T5EncoderModel", + }, + ) 
cls._validate_filename_looks_like_bnb_quantized(mod) @@ -481,7 +525,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase): @classmethod def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - has_scb_key_suffix = mod.has_keys_ending_with("SCB") + has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB") if not has_scb_key_suffix: raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") @@ -592,23 +636,25 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. # Some main models have these keys, likely due to the creator merging in a LoRA. - has_key_with_lora_prefix = mod.has_keys_starting_with( + has_key_with_lora_prefix = has_keys_starting_with( + mod.load_state_dict(), { "lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_", - } + }, ) - has_key_with_lora_suffix = mod.has_keys_ending_with( + has_key_with_lora_suffix = has_keys_ending_with( + mod.load_state_dict(), { "to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight", - } + }, ) if not has_key_with_lora_prefix and not has_key_with_lora_suffix: @@ -754,7 +800,13 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): @classmethod def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: - if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): + if not has_keys_starting_with( + mod.load_state_dict(), + { + "encoder.conv_in", + "decoder.conv_in", + }, + ): raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") @classmethod @@ -786,7 +838,14 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"}) + _validate_class_name( + cls, + common_config_paths(mod.path), + 
{ + "AutoencoderKL", + "AutoencoderTiny", + }, + ) base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @@ -812,7 +871,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases: - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL elif cls._name_looks_like_sdxl(mod): @@ -843,7 +902,14 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M _validate_override_fields(cls, fields) - _validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"}) + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "ControlNetModel", + "FluxControlNetModel", + }, + ) base = fields.get("base") or cls._get_base_or_raise(mod) @@ -851,7 +917,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases: - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) if config.get("_class_name") == "FluxControlNetModel": return BaseModelType.Flux @@ -900,7 +966,8 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, @classmethod def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: - if not mod.has_keys_starting_with( + if not has_keys_starting_with( + mod.load_state_dict(), { "controlnet", "control_model", @@ -911,7 +978,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so # delicate. 
"controlnet_blocks", - } + }, ): raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint") @@ -1268,7 +1335,8 @@ class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, Mo @classmethod def _validate_is_flux(cls, mod: ModelOnDisk) -> None: - if not mod.has_keys_exact( + if not has_keys_exact( + mod.load_state_dict(), { "double_blocks.0.img_attn.norm.key_norm.scale", "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", @@ -1426,7 +1494,7 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { # SD 1.x and 2.x "StableDiffusionPipeline", @@ -1527,7 +1595,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase) _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "StableDiffusion3Pipeline", "SD3Transformer2DModel", @@ -1548,7 +1616,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase) @classmethod def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]: # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) submodels: dict[SubModelType, SubmodelDefinition] = {} @@ -1601,8 +1669,10 @@ class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigB _validate_class_name( cls, - mod.common_config_paths(), - {"CogView4Pipeline"}, + common_config_paths(mod.path), + { + "CogView4Pipeline", + }, ) repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) @@ -1706,13 +1776,14 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): @classmethod def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: - if not 
mod.has_keys_starting_with( + if not has_keys_starting_with( + mod.load_state_dict(), { "image_proj.", "ip_adapter.", # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". "ip_adapter_proj_model.", - } + }, ): raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") @@ -1778,7 +1849,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "CLIPModel", "CLIPTextModel", @@ -1792,7 +1863,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.G: @@ -1816,7 +1887,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "CLIPModel", "CLIPTextModel", @@ -1830,7 +1901,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) clip_variant = _get_clip_variant_type_from_config(config) if clip_variant is not ClipVariantType.L: @@ -1852,7 +1923,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "CLIPVisionModelWithProjection", }, @@ -1882,7 +1953,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "T2IAdapter", }, @@ 
-1894,7 +1965,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases: - config = _get_config_or_raise(cls, mod.common_config_paths()) + config = _get_config_or_raise(cls, common_config_paths(mod.path)) adapter_type = config.get("adapter_type") @@ -1955,7 +2026,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "SiglipModel", }, @@ -1998,7 +2069,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): _validate_class_name( cls, - mod.common_config_paths(), + common_config_paths(mod.path), { "LlavaOnevisionForConditionalGeneration", }, diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 6927200922..502ca596a6 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -128,26 +128,3 @@ class ModelOnDisk: f"Please specify the intended file using the 'path' argument" ) return path - - def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool: - _keys = {keys} if isinstance(keys, str) else keys - state_dict = self.load_state_dict(path) - return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)}) - - def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool: - _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes - state_dict = self.load_state_dict(path) - return any( - any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str) - ) - - def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool: - _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes - state_dict = self.load_state_dict(path) - return any( - 
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str) - ) - - def common_config_paths(self) -> set[Path]: - """Returns common config file paths for models stored in directories.""" - return {self.path / "config.json", self.path / "model_index.json"}