From 3b606b6d630a3a65c6d5cf953b98d1934cb04b2b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 Sep 2025 18:50:06 +1000 Subject: [PATCH] feat(mm): make match helpers more succint --- invokeai/backend/model_manager/config.py | 250 ++++++++--------------- 1 file changed, 89 insertions(+), 161 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 29624eef95..640ab56318 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -170,6 +170,24 @@ def _validate_overrides( return is_perfect_match +def _raise_if_not_file( + config_class: type, + mod: ModelOnDisk, +) -> None: + """Raise NotAMatch if the model path is not a file.""" + if not mod.path.is_file(): + raise NotAMatch(config_class, "model path is not a file") + + +def _raise_if_not_dir( + config_class: type, + mod: ModelOnDisk, +) -> None: + """Raise NotAMatch if the model path is not a directory.""" + if not mod.path.is_dir(): + raise NotAMatch(config_class, "model path is not a directory") + + class SubmodelDefinition(BaseModel): path_or_prefix: str model_type: ModelType @@ -392,21 +410,12 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): + _raise_if_not_dir(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): return cls(**fields) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") - - _validate_class_names( - config_class=cls, - config_path=mod.path / "text_encoder_2" / "config.json", - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models) has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() @@ -431,19 +440,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): + _raise_if_not_dir(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): return cls(**fields) # Heuristic: Look for the T5EncoderModel class name in the config - _validate_class_names( - config_class=cls, - config_path=mod.path / "text_encoder_2" / "config.json", - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) # Heuristic: look for the quantization in the filename name filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) @@ -467,16 +470,11 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + # OMI LoRAs are always files + _raise_if_not_file(cls, mod) - # Heuristic: OMI LoRAs are always files, never directories - if mod.path.is_dir(): - raise NotAMatch(cls, "model path is a directory, not a file") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) # Heuristic: differential diagnosis vs ControlLoRA and Diffusers if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: @@ -522,16 +520,11 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + # LyCORIS LoRAs are always files, never directories + _raise_if_not_file(cls, mod) - # Heuristic: LyCORIS LoRAs are always files, never directories - if mod.path.is_dir(): - raise NotAMatch(cls, "model path is a directory, not a file") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) # Heuristic: differential diagnosis vs ControlLoRA and Diffusers if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: @@ -598,16 +591,11 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + # Diffusers-style models always directories + _raise_if_not_dir(cls, mod) - # Heuristic: Diffusers LoRAs are always directories, never files - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers @@ -637,15 +625,10 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_file(cls, mod) - if mod.path.is_dir(): - raise NotAMatch(cls, "model path is a directory, not a file") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") @@ -684,21 +667,12 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): + _raise_if_not_dir(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): return cls(**fields) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") - - _validate_class_names( - config_class=cls, - config_path=mod.path / "config.json", - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES) base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @@ -828,15 +802,10 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_file(cls, mod) - if mod.path.is_dir(): - raise NotAMatch(cls, "model path is a directory, not a file") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) if not cls._file_looks_like_embedding(mod): raise NotAMatch(cls, "model does not look like a textual inversion embedding file") @@ -861,15 +830,10 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) for p in mod.weight_files(): if cls._file_looks_like_embedding(mod, p): @@ -991,23 +955,14 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) config_path = mod.path / "config.json" - _validate_class_names( - config_class=cls, - config_path=config_path, - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) config = _get_config_or_raise(cls, config_path) @@ -1036,23 +991,14 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) config_path = mod.path / "config.json" - _validate_class_names( - config_class=cls, - config_path=config_path, - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) config = _get_config_or_raise(cls, config_path) clip_variant = cls.get_clip_variant_type(config) @@ -1080,23 +1026,14 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) config_path = mod.path / "config.json" - _validate_class_names( - config_class=cls, - config_path=config_path, - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) return cls(**fields) @@ -1123,15 +1060,10 @@ class SpandrelImageToImageConfig(ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_file(cls, mod) - if not mod.path.is_file(): - raise NotAMatch(cls, "model path is a directory, not a file") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) try: # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were @@ -1166,33 +1098,38 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) config_path = mod.path / "config.json" - _validate_class_names( - config_class=cls, - config_path=config_path, - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) return cls(**fields) -class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase): +class FluxReduxConfig(ModelConfigBase): """Model config for FLUX Tools Redux model.""" type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + VALID_OVERRIDES: ClassVar = { + "type": ModelType.SigLIP, + "format": ModelFormat.Diffusers, + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _raise_if_not_file(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) + + return cls(**fields) + class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): """Model config for Llava Onevision models.""" @@ -1212,23 +1149,14 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if _validate_overrides( - config_class=cls, - provided_overrides=fields, - valid_overrides=cls.VALID_OVERRIDES, - ): - return cls(**fields) + _raise_if_not_dir(cls, mod) - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) config_path = mod.path / "config.json" - _validate_class_names( - config_class=cls, - config_path=config_path, - valid_class_names=cls.VALID_CLASS_NAMES, - ) + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) return cls(**fields)