From 34be134c09ffa034f871b0344f70fface05402b9 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 7 Oct 2025 19:30:57 +1100
Subject: [PATCH] feat(mm): add sanity checks before probing paths

---
 .../backend/model_manager/configs/factory.py  | 160 ++++++++++++++----
 .../backend/model_manager/configs/unknown.py  |  27 ++-
 2 files changed, 151 insertions(+), 36 deletions(-)

diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py
index 8ff48e097f..03ab40ca5a 100644
--- a/invokeai/backend/model_manager/configs/factory.py
+++ b/invokeai/backend/model_manager/configs/factory.py
@@ -109,6 +109,29 @@ from invokeai.backend.model_manager.taxonomy import (
 logger = logging.getLogger(__name__)
 app_config = get_config()
 
+# Known model file extensions for sanity checking
+_MODEL_EXTENSIONS = {
+    ".safetensors",
+    ".ckpt",
+    ".pt",
+    ".pth",
+    ".bin",
+    ".gguf",
+    ".onnx",
+}
+
+# Known config file names for diffusers/transformers models
+_CONFIG_FILES = {
+    "model_index.json",
+    "config.json",
+}
+
+# Maximum number of files in a directory to be considered a model
+_MAX_FILES_IN_MODEL_DIR = 50
+
+# Maximum depth to search for model files in directories
+_MAX_SEARCH_DEPTH = 2
+
 
 # The types are listed explicitly because IDEs/LSPs can't identify the correct types
 # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
@@ -276,6 +299,68 @@ class ModelConfigFactory:
 
         return fields
 
+    @staticmethod
+    def _validate_path_looks_like_model(path: Path) -> None:
+        """Perform basic sanity checks to ensure a path looks like a model.
+
+        This prevents wasting time trying to identify obviously non-model paths like
+        home directories or downloads folders. Raises RuntimeError if the path doesn't
+        pass basic checks.
+
+        Args:
+            path: The path to validate
+
+        Raises:
+            RuntimeError: If the path doesn't look like a model
+        """
+        if path.is_file():
+            # For files, just check the extension
+            if path.suffix.lower() not in _MODEL_EXTENSIONS:
+                raise RuntimeError(
+                    f"File extension {path.suffix} is not a recognized model format. "
+                    f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
+                )
+        else:
+            # For directories, do a quick file count check with early exit
+            total_files = 0
+            for item in path.rglob("*"):
+                if item.is_file():
+                    total_files += 1
+                    if total_files > _MAX_FILES_IN_MODEL_DIR:
+                        raise RuntimeError(
+                            f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
+                            "This looks like a general-purpose directory rather than a model. "
+                            "Please provide a path to a specific model file or model directory."
+                        )
+
+            # Check if it has config files at root (diffusers/transformers marker)
+            has_root_config = any((path / config).exists() for config in _CONFIG_FILES)
+
+            if has_root_config:
+                # Has a config file, looks like a valid model directory
+                return
+
+            # Otherwise, search for model files within depth limit
+            def find_model_files(current_path: Path, depth: int) -> bool:
+                if depth > _MAX_SEARCH_DEPTH:
+                    return False
+                try:
+                    for item in current_path.iterdir():
+                        if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS:
+                            return True
+                        elif item.is_dir() and find_model_files(item, depth + 1):
+                            return True
+                except PermissionError:
+                    pass
+                return False
+
+            if not find_model_files(path, 0):
+                raise RuntimeError(
+                    f"No model files or config files found in directory {path}. "
+                    f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
+                    f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
+                )
+
     @staticmethod
     def from_model_on_disk(
         mod: str | Path | ModelOnDisk,
@@ -290,6 +375,10 @@ class ModelConfigFactory:
         if isinstance(mod, Path | str):
             mod = ModelOnDisk(Path(mod), hash_algo)
 
+        # Perform basic sanity checks before attempting any config matching
+        # This rejects obviously non-model paths early, saving time
+        ModelConfigFactory._validate_path_looks_like_model(mod.path)
+
         # We will always need these fields to build any model config.
         fields = ModelConfigFactory.build_common_fields(mod, override_fields)
 
@@ -317,48 +406,53 @@ class ModelConfigFactory:
                 logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
             except Exception as e:
                 results[class_name] = e
-                logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
+                logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
 
         matches = [r for r in results.values() if isinstance(r, Config_Base)]
 
-        if not matches and app_config.allow_unknown_models:
-            logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
-            return Unknown_Config(
-                **fields,
-                # Override the type/format/base to ensure it's marked as unknown.
-                base=BaseModelType.Unknown,
-                type=ModelType.Unknown,
-                format=ModelFormat.Unknown,
-            )
+        if not matches:
+            # Nothing matched, not even the Unknown_Config fallback (e.g. unknown models are disallowed), so fail.
+            msg = f"No model config matched for model {mod.path}"
+            logger.error(msg)
+            raise RuntimeError(msg)
+
+        # It is possible that we have multiple matches. We need to prioritize them.
+        #
+        # Known cases where multiple matches can occur:
+        # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
+        # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
+        #   a config.json file. Prefer the main model.
+        #
+        # Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
+        # sophisticated approach.
+        #
+        # Unknown models should always be the last resort fallback.
+        def sort_key(m: AnyModelConfig) -> int:
+            match m.type:
+                case ModelType.Main:
+                    return 0
+                case ModelType.LoRA:
+                    return 1
+                case ModelType.CLIPEmbed:
+                    return 2
+                case ModelType.Unknown:
+                    # Unknown should always be tried last as a fallback
+                    return 999
+                case _:
+                    return 3
+
+        matches.sort(key=sort_key)
 
         if len(matches) > 1:
-            # We have multiple matches, in which case at most 1 is correct. We need to pick one.
-            #
-            # Known cases:
-            # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
-            # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
-            #   a config.json file. Prefer the main model.
-            #
-            # Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
-            # sophisticated approach.
-            def sort_key(m: AnyModelConfig) -> int:
-                match m.type:
-                    case ModelType.Main:
-                        return 0
-                    case ModelType.LoRA:
-                        return 1
-                    case ModelType.CLIPEmbed:
-                        return 2
-                    case _:
-                        return 3
-
-            matches.sort(key=sort_key)
             logger.warning(
-                f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}."
+                f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
             )
 
         instance = matches[0]
-        logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
+        if isinstance(instance, Unknown_Config):
+            logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
+        else:
+            logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
 
         # Now do any post-processing needed for specific model types/bases/etc.
         match instance.type:
diff --git a/invokeai/backend/model_manager/configs/unknown.py b/invokeai/backend/model_manager/configs/unknown.py
index 10aad75566..2371cca089 100644
--- a/invokeai/backend/model_manager/configs/unknown.py
+++ b/invokeai/backend/model_manager/configs/unknown.py
@@ -2,6 +2,7 @@ from typing import Any, Literal, Self
 
 from pydantic import Field
 
+from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.model_manager.configs.base import Config_Base
 from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
@@ -11,14 +12,34 @@ from invokeai.backend.model_manager.taxonomy import (
     ModelType,
 )
 
+app_config = get_config()
+
 
 class Unknown_Config(Config_Base):
-    """Model config for unknown models, used as a fallback when we cannot identify a model."""
+    """Model config for unknown models, used as a fallback when we cannot positively identify a model."""
 
     base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
     type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
     format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
 
     @classmethod
-    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
-        raise NotAMatchError("unknown model config cannot match any model")
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        """Create an Unknown_Config for models that couldn't be positively identified.
+
+        Note: Basic path validation (file extensions, directory structure) is already
+        performed by ModelConfigFactory before this method is called.
+        """
+        if not app_config.allow_unknown_models:
+            raise NotAMatchError("unknown models are not allowed by configuration")
+
+        # Drop caller-supplied base/type/format first. Passing a key both via ** and as an explicit keyword
+        # raises TypeError ("got multiple values for keyword argument") instead of overriding it.
+        fields = {k: v for k, v in override_fields.items() if k not in ("base", "type", "format")}
+
+        return cls(
+            **fields,
+            # Force the type/format/base so the config is always marked as unknown.
+            base=BaseModelType.Unknown,
+            type=ModelType.Unknown,
+            format=ModelFormat.Unknown,
+        )