From 2206b285766ec4c048da226c5335a5b79292c64d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 9 Oct 2025 20:53:12 +1100
Subject: [PATCH] refactor(mm): continued iteration on model identification

---
 .../model_install/model_install_common.py     |   6 +
 .../model_install/model_install_default.py    |   9 +-
 .../migrations/migration_23.py                |  20 +-
 .../backend/model_manager/configs/factory.py  | 171 ++++++++++--------
 .../backend/model_manager/configs/unknown.py  |   3 -
 invokeai/backend/model_manager/configs/vae.py |  24 +--
 .../model_manager/load/model_loaders/lora.py  |  11 +-
 7 files changed, 149 insertions(+), 95 deletions(-)

diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py
index 67832466f3..f098a73d88 100644
--- a/invokeai/app/services/model_install/model_install_common.py
+++ b/invokeai/app/services/model_install/model_install_common.py
@@ -15,6 +15,12 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
 from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
 
 
+class InvalidModelConfigException(Exception):
+    """Raised when a model configuration is invalid."""
+
+    pass
+
+
 class InstallStatus(str, Enum):
     """State of an install job running in the background."""
 
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 53bb5cc12d..e2ac60435a 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -27,6 +27,7 @@ from invokeai.app.services.model_install.model_install_common import (
     MODEL_SOURCE_TO_TYPE_MAP,
     HFModelSource,
     InstallStatus,
+    InvalidModelConfigException,
     LocalModelSource,
     ModelInstallJob,
     ModelSource,
@@ -599,12 +600,18 @@ class ModelInstallService(ModelInstallServiceBase):
         hash_algo = self._app_config.hashing_algorithm
         fields = config.model_dump()
 
-        return ModelConfigFactory.from_model_on_disk(
+        result = ModelConfigFactory.from_model_on_disk(
             mod=model_path,
             override_fields=deepcopy(fields),
             hash_algo=hash_algo,
+            allow_unknown=self._app_config.allow_unknown_models,
         )
 
+        if result.config is None:
+            raise InvalidModelConfigException(f"Could not identify model type for {model_path}")
+
+        return result.config
+
     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
     ) -> str:
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py
index 8c4245002d..f02cf04f0b 100644
--- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py
@@ -52,8 +52,10 @@ class Migration23Callback:
         # In v6.8.0 we made some improvements to the model taxonomy and the model config schemas. There are a few
         # changes we need to make to old configs to bring them up to date.
-        base = config_dict.get("base")
         type = config_dict.get("type")
+        format = config_dict.get("format")
+        base = config_dict.get("base")
+
 
         if base == BaseModelType.Flux.value and type == ModelType.Main.value:
             # Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
             # variants.
@@ -90,7 +92,7 @@ class Migration23Callback:
                 BaseModelType.StableDiffusionXL.value,
                 BaseModelType.StableDiffusionXLRefiner.value,
             }
-            and type == "main"
+            and type == ModelType.Main.value
         ):
             # Prior to v6.8.0, the prediction_type field was optional and would default to Epsilon if not present.
             # We now make it explicit and always present. Use the existing value if present, otherwise default to
@@ -99,6 +101,20 @@ class Migration23Callback:
             # It's only on SD1.x, SD2.x, and SDXL main models.
             config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
 
+        if base == BaseModelType.Flux.value and type == ModelType.LoRA.value and format == ModelFormat.Diffusers.value:
+            # Prior to v6.8.0, we used the Diffusers format for FLUX LoRA models that used the diffusers _key_
+            # structure. This was misleading, as everywhere else in the application, we used the Diffusers format
+            # to indicate that the model files were in the Diffusers _file_ format (i.e. a directory containing
+            # the weights and config files).
+            #
+            # At runtime, we check the LoRA's state dict directly to determine the key structure, so we do not need
+            # to rely on the format field for this purpose. As of v6.8.0, we always use the LyCORIS format for single-
+            # file LoRAs, regardless of the key structure.
+            #
+            # This change means LoRA model identification no longer needs a special case for FLUX LoRAs in the
+            # diffusers key format.
+            config_dict["format"] = ModelFormat.LyCORIS.value
+
         if type == ModelType.CLIPVision.value:
             # Prior to v6.8.0, some CLIP Vision models were associated with a specific base model architecture:
             # - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL
diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py
index 042c08d976..dcd7c4c0ed 100644
--- a/invokeai/backend/model_manager/configs/factory.py
+++ b/invokeai/backend/model_manager/configs/factory.py
@@ -1,4 +1,5 @@
 import logging
+from dataclasses import dataclass
 from pathlib import Path
 from typing import (
     Union,
@@ -241,6 +242,30 @@
 additional logic in the future.
 """
 
+
+@dataclass
+class ModelClassificationResult:
+    """Result of attempting to classify a model on disk into a specific model config.
+
+    Attributes:
+        config: The best matching model config, or None if no match was found.
+        details: A mapping of model config class names to either an instance of that class (if it matched)
+            or an Exception (if it didn't match or an error occurred during matching).
+    """
+
+    config: AnyModelConfig | None
+    details: dict[str, AnyModelConfig | Exception]
+
+    @property
+    def all_matches(self) -> list[AnyModelConfig]:
+        """Returns a list of all matching model configs found."""
+        return [r for r in self.details.values() if isinstance(r, Config_Base)]
+
+    @property
+    def match_count(self) -> int:
+        """Returns the number of matching model configs found."""
+        return len(self.all_matches)
+
+
 class ModelConfigFactory:
     @staticmethod
     def from_dict(fields: dict[str, Any]) -> AnyModelConfig:
@@ -311,23 +336,25 @@ class ModelConfigFactory:
             path: The path to validate
 
         Raises:
-            RuntimeError: If the path doesn't look like a model
+            ValueError: If the path doesn't look like a model
         """
         if path.is_file():
             # For files, just check the extension
             if path.suffix.lower() not in _MODEL_EXTENSIONS:
-                raise RuntimeError(
+                raise ValueError(
                     f"File extension {path.suffix} is not a recognized model format. "
                     f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
                 )
         else:
             # For directories, do a quick file count check with early exit
             total_files = 0
-            for item in path.rglob("*"):
+            # Ignore hidden files and directories, including files nested inside hidden directories
+            paths_to_check = (
+                p for p in path.rglob("*") if not any(part.startswith(".") for part in p.relative_to(path).parts)
+            )
+            for item in paths_to_check:
                 if item.is_file():
                     total_files += 1
                     if total_files > _MAX_FILES_IN_MODEL_DIR:
-                        raise RuntimeError(
+                        raise ValueError(
                             f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
                             "This looks like a general-purpose directory rather than a model. "
                             "Please provide a path to a specific model file or model directory."
@@ -355,22 +382,58 @@ class ModelConfigFactory:
                 return False
 
         if not find_model_files(path, 0):
-            raise RuntimeError(
+            raise ValueError(
                 f"No model files or config files found in directory {path}. "
                 f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
                 f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
             )
 
+    @staticmethod
+    def matches_sort_key(m: AnyModelConfig) -> int:
+        """Sort key function to prioritize model config matches in case of multiple matches."""
+
+        # It is possible that we have multiple matches. We need to prioritize them.
+
+        # Known cases where multiple matches can occur:
+        # - SD main models can look like a LoRA when they have LoRA weights merged in. Prefer the main model.
+        # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
+        #   a config.json file. Prefer the main model.
+
+        # Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
+        # sophisticated approach.
+        match m.type:
+            case ModelType.Main:
+                return 0
+            case ModelType.LoRA:
+                return 1
+            case ModelType.CLIPEmbed:
+                return 2
+            case _:
+                return 3
+
     @staticmethod
     def from_model_on_disk(
         mod: str | Path | ModelOnDisk,
         override_fields: dict[str, Any] | None = None,
         hash_algo: HASHING_ALGORITHMS = "blake3_single",
-    ) -> AnyModelConfig:
-        """
-        Returns the best matching ModelConfig instance from a model's file/folder path.
-        Raises InvalidModelConfigException if no valid configuration is found.
-        Created to deprecate ModelProbe.probe
+        allow_unknown: bool = True,
+    ) -> ModelClassificationResult:
+        """Classify a model on disk and return the best matching model config.
+
+        Args:
+            mod: The model on disk to classify. Can be a path (str or Path) or a ModelOnDisk instance.
+            override_fields: Optional dictionary of fields to override. These fields will take precedence
+                over the values extracted from the model on disk, but this cannot force a match if the
+                model on disk doesn't actually match the config class.
+            hash_algo: The hashing algorithm to use when computing the model hash if needed.
+            allow_unknown: If True, fall back to Unknown_Config when no config class matches; if False,
+                return a result with config=None instead.
+
+        Returns:
+            A ModelClassificationResult containing the best matching model config (or None if no match)
+            and a mapping of all attempted model config classes to either an instance of that class (if it matched)
+            or an Exception (if it didn't match or an error occurred during matching).
+
+        Raises:
+            ValueError: If the provided path doesn't look like a model.
         """
         if isinstance(mod, Path | str):
             mod = ModelOnDisk(Path(mod), hash_algo)
@@ -384,93 +447,57 @@ class ModelConfigFactory:
         # Store results as a mapping of config class to either an instance of that class or an exception
         # that was raised when trying to build it.
-        results: dict[str, AnyModelConfig | Exception] = {}
+        details: dict[str, AnyModelConfig | Exception] = {}
 
         # Try to build an instance of each model config class that uses the classify API.
         # Each class will either return an instance of itself or raise NotAMatchError if it doesn't match.
         # Other exceptions may be raised if something unexpected happens during matching or building.
-        for config_class in Config_Base.CONFIG_CLASSES:
-            class_name = config_class.__name__
+        for candidate_class in filter(lambda x: x is not Unknown_Config, Config_Base.CONFIG_CLASSES):
+            candidate_name = candidate_class.__name__
             try:
-                instance = config_class.from_model_on_disk(mod, fields)
                 # Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of
                 # the AnyModelConfig union.
-                results[class_name] = instance  # type: ignore
+                details[candidate_name] = candidate_class.from_model_on_disk(mod, fields)  # type: ignore
             except NotAMatchError as e:
-                results[class_name] = e
-                logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
+                # This means the model didn't match this config class. It's not an error, just no match.
+                details[candidate_name] = e
             except ValidationError as e:
                 # This means the model matched, but we couldn't create the pydantic model instance for the config.
                 # Maybe invalid overrides were provided?
-                results[class_name] = e
-                logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
+                details[candidate_name] = e
             except Exception as e:
-                results[class_name] = e
-                logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
+                # Some other unexpected error occurred. Store the exception for reporting later.
+                details[candidate_name] = e
 
         # Extract just the successful matches
-        # NOTE: This will include Unknown_Config matches, which we will handle later.
-        matches = [r for r in results.values() if isinstance(r, Config_Base)]
+        matches = [r for r in details.values() if isinstance(r, Config_Base)]
 
         if not matches:
-            # No matches at all. This should be very rare, but just in case, we will fall back to Unknown_Config.
-            msg = f"No model config matched for model {mod.path}"
-            logger.error(msg)
-            raise RuntimeError(msg)
+            if not allow_unknown:
+                # No matches and we are not allowed to fall back to Unknown_Config
+                return ModelClassificationResult(config=None, details=details)
+            else:
+                # Fall back to Unknown_Config.
+                # This should always succeed, as Unknown_Config.from_model_on_disk never raises NotAMatchError.
+                config = Unknown_Config.from_model_on_disk(mod, fields)
+                details[Unknown_Config.__name__] = config
+                return ModelClassificationResult(config=config, details=details)
 
-        # It is possible that we have multiple matches. We need to prioritize them.
-        #
-        # Known cases where multiple matches can occur:
-        # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
-        # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
-        #   a config.json file. Prefer the main model.
-        #
-        # Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
-        # sophisticated approach.
-        #
-        # Unknown models should always be the last resort fallback.
-        def sort_key(m: AnyModelConfig) -> int:
-            match m.type:
-                case ModelType.Main:
-                    return 0
-                case ModelType.LoRA:
-                    return 1
-                case ModelType.CLIPEmbed:
-                    return 2
-                case ModelType.Unknown:
-                    # Unknown should always be tried last as a fallback
-                    return 999
-                case _:
-                    return 3
-
-        matches.sort(key=sort_key)
-
-        # Warn if we have multiple non-unknown matches
-        has_unknown = any(isinstance(m, Unknown_Config) for m in matches)
-        real_match_count = len(matches) - (1 if has_unknown else 0)
-        if real_match_count > 1:
-            logger.warning(
-                f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
-            )
-
-        instance = matches[0]
-        if isinstance(instance, Unknown_Config):
-            logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
-        else:
-            logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
+        matches.sort(key=ModelConfigFactory.matches_sort_key)
+        config = matches[0]
 
         # Now do any post-processing needed for specific model types/bases/etc.
-        match instance.type:
+        match config.type:
             case ModelType.Main:
-                instance.default_settings = MainModelDefaultSettings.from_base(instance.base)
+                config.default_settings = MainModelDefaultSettings.from_base(config.base)
             case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
-                instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name)
+                config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
             case ModelType.LoRA:
-                instance.default_settings = LoraModelDefaultSettings()
+                config.default_settings = LoraModelDefaultSettings()
             case _:
                 pass
 
-        return instance
+        return ModelClassificationResult(config=config, details=details)
 
 
 MODEL_NAME_TO_PREPROCESSOR = {
diff --git a/invokeai/backend/model_manager/configs/unknown.py b/invokeai/backend/model_manager/configs/unknown.py
index a145091452..42ac62ebca 100644
--- a/invokeai/backend/model_manager/configs/unknown.py
+++ b/invokeai/backend/model_manager/configs/unknown.py
@@ -5,7 +5,6 @@ from pydantic import Field
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.model_manager.configs.base import Config_Base
-from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
@@ -30,8 +29,6 @@ class Unknown_Config(Config_Base):
         Note: Basic path validation (file extensions, directory structure) is already performed
         by ModelConfigFactory before this method is called.
         """
-        if not app_config.allow_unknown_models:
-            raise NotAMatchError("unknown models are not allowed by configuration")
 
         cloned_override_fields = deepcopy(override_fields)
         cloned_override_fields.pop("base", None)
diff --git a/invokeai/backend/model_manager/configs/vae.py b/invokeai/backend/model_manager/configs/vae.py
index c171595bd6..2525e0a1e4 100644
--- a/invokeai/backend/model_manager/configs/vae.py
+++ b/invokeai/backend/model_manager/configs/vae.py
@@ -116,15 +116,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
             },
         )
 
-        cls._validate_base(mod)
+        # Unfortunately it is difficult to distinguish SD1 and SDXL VAEs by config alone, so we may need to
+        # guess based on name if the config is inconclusive.
+        override_name = override_fields.get("name")
+        cls._validate_base(mod, override_name)
 
         return cls(**override_fields)
 
     @classmethod
-    def _validate_base(cls, mod: ModelOnDisk) -> None:
+    def _validate_base(cls, mod: ModelOnDisk, override_name: str | None = None) -> None:
         """Raise `NotAMatchError` if the model base does not match this config class."""
         expected_base = cls.model_fields["base"].default
-        recognized_base = cls._get_base_or_raise(mod)
+        recognized_base = cls._get_base_or_raise(mod, override_name)
         if expected_base is not recognized_base:
             raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
 
@@ -134,25 +137,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
         return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
 
     @classmethod
-    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
+    def _name_looks_like_sdxl(cls, mod: ModelOnDisk, override_name: str | None = None) -> bool:
         # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
         # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
         # we can do is guess based on name.
-        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))
+        return bool(re.search(r"xl\b", override_name or mod.path.name, re.IGNORECASE))
 
     @classmethod
-    def _guess_name(cls, mod: ModelOnDisk) -> str:
-        name = mod.path.name
-        if name == "vae":
-            name = mod.path.parent.name
-        return name
-
-    @classmethod
-    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
+    def _get_base_or_raise(cls, mod: ModelOnDisk, override_name: str | None = None) -> BaseModelType:
         config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
         if cls._config_looks_like_sdxl(config_dict):
             return BaseModelType.StableDiffusionXL
-        elif cls._name_looks_like_sdxl(mod):
+        elif cls._name_looks_like_sdxl(mod, override_name):
             return BaseModelType.StableDiffusionXL
         else:
             # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 29fb815d54..b97c3efeb1 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -30,6 +30,7 @@ from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
     lora_model_from_flux_control_state_dict,
 )
 from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
 from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
@@ -96,15 +97,19 @@ class LoRALoader(ModelLoader):
             state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
             model = lora_model_from_sd_state_dict(state_dict=state_dict)
         elif self._model_base == BaseModelType.Flux:
-            if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
+            if config.format is ModelFormat.OMI:
                 # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
                 # distributed as a single file without the associated metadata containing the alpha value. We chose
                 # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
                 # is a popular choice. For example, in the diffusers training scripts:
                 # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
+                #
+                # We assume the same for LyCORIS models in diffusers key format.
                 model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
-            elif config.format == ModelFormat.LyCORIS:
-                if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
+            elif config.format is ModelFormat.LyCORIS:
+                if is_state_dict_likely_in_flux_diffusers_format(state_dict=state_dict):
+                    model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
+                elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
                     model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
                 elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
                     model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
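
Reviewer note (illustrative sketch, not part of the patch): how the new
classification API introduced here is consumed. The model path below is
hypothetical, and `allow_unknown=False` mirrors the installer's use of
`allow_unknown_models`.

    from pathlib import Path

    from invokeai.backend.model_manager.configs.factory import ModelConfigFactory

    result = ModelConfigFactory.from_model_on_disk(
        mod=Path("/models/my-flux-lora.safetensors"),  # hypothetical path
        allow_unknown=False,  # return config=None instead of falling back to Unknown_Config
    )

    if result.config is None:
        # `details` maps each candidate config class name to either a matching config
        # instance or the exception it raised, so failed identification is debuggable.
        for candidate, outcome in result.details.items():
            print(f"{candidate}: {outcome!r}")
    else:
        print(f"Identified as {type(result.config).__name__} ({result.match_count} match(es))")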