mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): add sanity checks before probing paths
This commit is contained in:
@@ -109,6 +109,29 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = get_config()
|
||||
|
||||
# Known model file extensions for sanity checking (lower-case, including the leading dot).
# Frozen so that these module-level lookup tables cannot be mutated by accident.
_MODEL_EXTENSIONS = frozenset(
    {
        ".safetensors",
        ".ckpt",
        ".pt",
        ".pth",
        ".bin",
        ".gguf",
        ".onnx",
    }
)

# Known config file names for diffusers/transformers models
_CONFIG_FILES = frozenset(
    {
        "model_index.json",
        "config.json",
    }
)

# Maximum number of files (counted recursively) a directory may contain and still be
# considered a model
_MAX_FILES_IN_MODEL_DIR = 50

# Maximum depth to search for model files in directories. The directory being validated
# is depth 0; anything deeper than this value is not searched.
_MAX_SEARCH_DEPTH = 2
|
||||
|
||||
|
||||
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
|
||||
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
|
||||
@@ -276,6 +299,68 @@ class ModelConfigFactory:
|
||||
|
||||
return fields
|
||||
|
||||
@staticmethod
def _validate_path_looks_like_model(path: Path) -> None:
    """Perform basic sanity checks to ensure a path looks like a model.

    This prevents wasting time trying to identify obviously non-model paths like
    home directories or downloads folders.

    A file passes when its extension (compared case-insensitively) is one of
    ``_MODEL_EXTENSIONS``. A directory passes when it contains no more than
    ``_MAX_FILES_IN_MODEL_DIR`` files (counted recursively) and either has one of
    ``_CONFIG_FILES`` at its root or contains a file with a model extension no deeper
    than ``_MAX_SEARCH_DEPTH`` (the directory itself is depth 0).

    Args:
        path: The path to validate

    Raises:
        RuntimeError: If the path doesn't look like a model, including when it is
            neither an existing file nor an existing directory.
    """
    if path.is_file():
        # For files, just check the extension
        if path.suffix.lower() not in _MODEL_EXTENSIONS:
            raise RuntimeError(
                f"File extension {path.suffix} is not a recognized model format. "
                f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
            )
        return

    if not path.is_dir():
        # Without this guard a missing path falls through to the directory checks and
        # surfaces as an OSError from iterdir() instead of the documented RuntimeError.
        raise RuntimeError(f"Path {path} does not exist or is not a file or directory.")

    # For directories, do a quick file count check with early exit
    total_files = 0
    for item in path.rglob("*"):
        try:
            is_regular_file = item.is_file()
        except OSError:
            # An entry we are not allowed to stat should not abort the sanity check.
            continue
        if not is_regular_file:
            continue
        total_files += 1
        if total_files > _MAX_FILES_IN_MODEL_DIR:
            raise RuntimeError(
                f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
                "This looks like a general-purpose directory rather than a model. "
                "Please provide a path to a specific model file or model directory."
            )

    # Check if it has config files at root (diffusers/transformers marker)
    if any((path / config).exists() for config in _CONFIG_FILES):
        # Has a config file, looks like a valid model directory
        return

    # Otherwise, search for model files within depth limit
    def find_model_files(current_path: Path, depth: int) -> bool:
        """Return True if a model file exists under current_path within the depth limit."""
        if depth > _MAX_SEARCH_DEPTH:
            return False
        try:
            for item in current_path.iterdir():
                if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS:
                    return True
                elif item.is_dir() and find_model_files(item, depth + 1):
                    return True
        except OSError:
            # Unreadable or vanished directories are treated as containing no model
            # files (best effort); PermissionError is a subclass of OSError.
            pass
        return False

    if not find_model_files(path, 0):
        raise RuntimeError(
            f"No model files or config files found in directory {path}. "
            f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
            f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
        )
|
||||
|
||||
@staticmethod
|
||||
def from_model_on_disk(
|
||||
mod: str | Path | ModelOnDisk,
|
||||
@@ -290,6 +375,10 @@ class ModelConfigFactory:
|
||||
if isinstance(mod, Path | str):
|
||||
mod = ModelOnDisk(Path(mod), hash_algo)
|
||||
|
||||
# Perform basic sanity checks before attempting any config matching
|
||||
# This rejects obviously non-model paths early, saving time
|
||||
ModelConfigFactory._validate_path_looks_like_model(mod.path)
|
||||
|
||||
# We will always need these fields to build any model config.
|
||||
fields = ModelConfigFactory.build_common_fields(mod, override_fields)
|
||||
|
||||
@@ -317,48 +406,53 @@ class ModelConfigFactory:
|
||||
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
|
||||
except Exception as e:
|
||||
results[class_name] = e
|
||||
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
|
||||
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
|
||||
|
||||
matches = [r for r in results.values() if isinstance(r, Config_Base)]
|
||||
|
||||
if not matches and app_config.allow_unknown_models:
|
||||
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
|
||||
return Unknown_Config(
|
||||
**fields,
|
||||
# Override the type/format/base to ensure it's marked as unknown.
|
||||
base=BaseModelType.Unknown,
|
||||
type=ModelType.Unknown,
|
||||
format=ModelFormat.Unknown,
|
||||
)
|
||||
if not matches:
|
||||
# No config class matched at all, not even the Unknown_Config fallback. This should be very rare; report it as an error.
|
||||
msg = f"No model config matched for model {mod.path}"
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# It is possible that we have multiple matches. We need to prioritize them.
|
||||
#
|
||||
# Known cases where multiple matches can occur:
|
||||
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
|
||||
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
|
||||
# a config.json file. Prefer the main model.
|
||||
#
|
||||
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
|
||||
# sophisticated approach.
|
||||
#
|
||||
# Unknown models should always be the last resort fallback.
|
||||
def sort_key(m: AnyModelConfig) -> int:
|
||||
match m.type:
|
||||
case ModelType.Main:
|
||||
return 0
|
||||
case ModelType.LoRA:
|
||||
return 1
|
||||
case ModelType.CLIPEmbed:
|
||||
return 2
|
||||
case ModelType.Unknown:
|
||||
# Unknown should always be tried last as a fallback
|
||||
return 999
|
||||
case _:
|
||||
return 3
|
||||
|
||||
matches.sort(key=sort_key)
|
||||
|
||||
if len(matches) > 1:
|
||||
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
|
||||
#
|
||||
# Known cases:
|
||||
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
|
||||
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
|
||||
# a config.json file. Prefer the main model.
|
||||
#
|
||||
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
|
||||
# sophisticated approach.
|
||||
def sort_key(m: AnyModelConfig) -> int:
|
||||
match m.type:
|
||||
case ModelType.Main:
|
||||
return 0
|
||||
case ModelType.LoRA:
|
||||
return 1
|
||||
case ModelType.CLIPEmbed:
|
||||
return 2
|
||||
case _:
|
||||
return 3
|
||||
|
||||
matches.sort(key=sort_key)
|
||||
logger.warning(
|
||||
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}."
|
||||
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
|
||||
)
|
||||
|
||||
instance = matches[0]
|
||||
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
|
||||
if isinstance(instance, Unknown_Config):
|
||||
logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
|
||||
else:
|
||||
logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
|
||||
|
||||
# Now do any post-processing needed for specific model types/bases/etc.
|
||||
match instance.type:
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Literal, Self
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_manager.configs.base import Config_Base
|
||||
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
@@ -11,14 +12,30 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
ModelType,
|
||||
)
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
class Unknown_Config(Config_Base):
    """Model config for unknown models, used as a fallback when we cannot positively identify a model."""

    base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
    type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
    format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Create an Unknown_Config for models that couldn't be positively identified.

        Note: Basic path validation (file extensions, directory structure) is already
        performed by ModelConfigFactory before this method is called.

        Args:
            mod: The model on disk. It is not inspected by this method.
            override_fields: Field values used to build the config. Any ``base``, ``type``
                or ``format`` entries are replaced with their Unknown values. The dict
                itself is not modified.

        Returns:
            A config whose base, type and format are all marked as unknown.

        Raises:
            NotAMatchError: If unknown models are not allowed by the app configuration.
        """
        if not app_config.allow_unknown_models:
            raise NotAMatchError("unknown models are not allowed by configuration")

        # Override the type/format/base to ensure it's marked as unknown. The values are
        # merged into one dict because passing them as explicit keyword arguments next to
        # **override_fields raises TypeError when override_fields already has those keys.
        unknown_fields: dict[str, Any] = {
            **override_fields,
            "base": BaseModelType.Unknown,
            "type": ModelType.Unknown,
            "format": ModelFormat.Unknown,
        }
        return cls(**unknown_fields)
|
||||
|
||||
Reference in New Issue
Block a user