From bbecc86d0fe7959069d45511163a4207bc16fc9a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 Sep 2025 18:07:51 +1000 Subject: [PATCH] tidy(mm): clarify that model id utils are private --- invokeai/backend/model_manager/config.py | 57 +++++++++++++----------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index c3e196b53c..297812f1cc 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -96,8 +96,11 @@ class NotAMatch(Exception): DEFAULTS_PRECISION = Literal["fp16", "fp32"] +# These utility functions are tightly coupled to the config classes below in order to make the process of raising +# NotAMatch exceptions as easy and consistent as possible. -def get_config_or_raise( + +def _get_config_or_raise( config_class: type, config_path: Path, ) -> dict[str, Any]: @@ -112,14 +115,14 @@ def get_config_or_raise( raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e -def raise_for_class_names( +def _validate_class_names( config_class: type, config_path: Path, valid_class_names: set[str], ) -> None: """Raise NotAMatch if the config file is missing or does not contain a valid class name.""" - config = get_config_or_raise(config_class, config_path) + config = _get_config_or_raise(config_class, config_path) try: if "_class_name" in config: @@ -135,8 +138,8 @@ def raise_for_class_names( raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}") -def matches_overrides( - config_class: "Type[AnyModelConfig]", +def _validate_overrides( + config_class: type, provided_overrides: dict[str, Any], valid_overrides: dict[str, Any], ) -> bool: @@ -389,7 +392,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -399,7 +402,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=mod.path / "text_encoder_2" / "config.json", valid_class_names=cls.VALID_CLASS_NAMES, @@ -428,7 +431,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -436,7 +439,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): return cls(**fields) # Heuristic: Look for the T5EncoderModel class name in the config - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=mod.path / "text_encoder_2" / "config.json", valid_class_names=cls.VALID_CLASS_NAMES, @@ -464,7 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -519,7 +522,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -595,7 +598,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -634,7 +637,7 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -681,7 +684,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -691,7 +694,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=mod.path / "config.json", valid_class_names=cls.VALID_CLASS_NAMES, @@ -721,7 +724,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): @classmethod def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - config = get_config_or_raise(cls, mod.path / "config.json") + config = _get_config_or_raise(cls, mod.path / "config.json") if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL elif cls._name_looks_like_sdxl(mod): @@ -825,7 +828,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -858,7 +861,7 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -988,7 +991,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -1000,13 +1003,13 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): config_path = mod.path / "config.json" - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=config_path, valid_class_names=cls.VALID_CLASS_NAMES, ) - config = get_config_or_raise(cls, config_path) + config = _get_config_or_raise(cls, config_path) clip_variant = cls.get_clip_variant_type(config) @@ -1033,7 +1036,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -1045,13 +1048,13 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): config_path = mod.path / "config.json" - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=config_path, valid_class_names=cls.VALID_CLASS_NAMES, ) - config = get_config_or_raise(cls, config_path) + config = _get_config_or_raise(cls, config_path) clip_variant = cls.get_clip_variant_type(config) if clip_variant is not ClipVariantType.L: @@ -1089,7 +1092,7 @@ class SpandrelImageToImageConfig(ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -1147,7 +1150,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - if matches_overrides( + if _validate_overrides( config_class=cls, provided_overrides=fields, valid_overrides=cls.VALID_OVERRIDES, @@ -1159,7 +1162,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): config_path = mod.path / "config.json" - raise_for_class_names( + _validate_class_names( config_class=cls, config_path=config_path, valid_class_names=cls.VALID_CLASS_NAMES,