feat(mm): add sanity checks before probing paths

This commit is contained in:
psychedelicious
2025-10-07 19:30:57 +11:00
parent 25619684c0
commit 01ca74e622
2 changed files with 147 additions and 36 deletions

View File

@@ -109,6 +109,29 @@ from invokeai.backend.model_manager.taxonomy import (
logger = logging.getLogger(__name__)
app_config = get_config()
# Known model file extensions for sanity checking
_MODEL_EXTENSIONS = {
".safetensors",
".ckpt",
".pt",
".pth",
".bin",
".gguf",
".onnx",
}
# Known config file names for diffusers/transformers models
_CONFIG_FILES = {
"model_index.json",
"config.json",
}
# Maximum number of files in a directory to be considered a model
_MAX_FILES_IN_MODEL_DIR = 50
# Maximum depth to search for model files in directories
_MAX_SEARCH_DEPTH = 2
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
@@ -276,6 +299,68 @@ class ModelConfigFactory:
return fields
@staticmethod
def _validate_path_looks_like_model(path: Path) -> None:
"""Perform basic sanity checks to ensure a path looks like a model.
This prevents wasting time trying to identify obviously non-model paths like
home directories or downloads folders. Raises RuntimeError if the path doesn't
pass basic checks.
Args:
path: The path to validate
Raises:
RuntimeError: If the path doesn't look like a model
"""
if path.is_file():
# For files, just check the extension
if path.suffix.lower() not in _MODEL_EXTENSIONS:
raise RuntimeError(
f"File extension {path.suffix} is not a recognized model format. "
f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
)
else:
# For directories, do a quick file count check with early exit
total_files = 0
for item in path.rglob("*"):
if item.is_file():
total_files += 1
if total_files > _MAX_FILES_IN_MODEL_DIR:
raise RuntimeError(
f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
"This looks like a general-purpose directory rather than a model. "
"Please provide a path to a specific model file or model directory."
)
# Check if it has config files at root (diffusers/transformers marker)
has_root_config = any((path / config).exists() for config in _CONFIG_FILES)
if has_root_config:
# Has a config file, looks like a valid model directory
return
# Otherwise, search for model files within depth limit
def find_model_files(current_path: Path, depth: int) -> bool:
if depth > _MAX_SEARCH_DEPTH:
return False
try:
for item in current_path.iterdir():
if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS:
return True
elif item.is_dir() and find_model_files(item, depth + 1):
return True
except PermissionError:
pass
return False
if not find_model_files(path, 0):
raise RuntimeError(
f"No model files or config files found in directory {path}. "
f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
)
@staticmethod
def from_model_on_disk(
mod: str | Path | ModelOnDisk,
@@ -290,6 +375,10 @@ class ModelConfigFactory:
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
# Perform basic sanity checks before attempting any config matching
# This rejects obviously non-model paths early, saving time
ModelConfigFactory._validate_path_looks_like_model(mod.path)
# We will always need these fields to build any model config.
fields = ModelConfigFactory.build_common_fields(mod, override_fields)
@@ -317,48 +406,53 @@ class ModelConfigFactory:
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
except Exception as e:
results[class_name] = e
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
matches = [r for r in results.values() if isinstance(r, Config_Base)]
if not matches and app_config.allow_unknown_models:
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
return Unknown_Config(
**fields,
# Override the type/format/base to ensure it's marked as unknown.
base=BaseModelType.Unknown,
type=ModelType.Unknown,
format=ModelFormat.Unknown,
)
if not matches:
# No matches at all. This should be very rare, but just in case, we will fall back to Unknown_Config.
msg = f"No model config matched for model {mod.path}"
logger.error(msg)
raise RuntimeError(msg)
# It is possible that we have multiple matches. We need to prioritize them.
#
# Known cases where multiple matches can occur:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
#
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
# sophisticated approach.
#
# Unknown models should always be the last resort fallback.
def sort_key(m: AnyModelConfig) -> int:
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case ModelType.Unknown:
# Unknown should always be tried last as a fallback
return 999
case _:
return 3
matches.sort(key=sort_key)
if len(matches) > 1:
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
#
# Known cases:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
#
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
# sophisticated approach.
def sort_key(m: AnyModelConfig) -> int:
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case _:
return 3
matches.sort(key=sort_key)
logger.warning(
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}."
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
)
instance = matches[0]
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
if isinstance(instance, Unknown_Config):
logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
else:
logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
# Now do any post-processing needed for specific model types/bases/etc.
match instance.type:

View File

@@ -2,6 +2,7 @@ from typing import Any, Literal, Self
from pydantic import Field
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
@@ -11,14 +12,30 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
)
app_config = get_config()
class Unknown_Config(Config_Base):
"""Model config for unknown models, used as a fallback when we cannot identify a model."""
"""Model config for unknown models, used as a fallback when we cannot positively identify a model."""
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatchError("unknown model config cannot match any model")
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
"""Create an Unknown_Config for models that couldn't be positively identified.
Note: Basic path validation (file extensions, directory structure) is already
performed by ModelConfigFactory before this method is called.
"""
if not app_config.allow_unknown_models:
raise NotAMatchError("unknown models are not allowed by configuration")
return cls(
**override_fields,
# Override the type/format/base to ensure it's marked as unknown.
base=BaseModelType.Unknown,
type=ModelType.Unknown,
format=ModelFormat.Unknown,
)