refactor(mm): continued iteration on model identification

This commit is contained in:
psychedelicious
2025-10-09 20:53:12 +11:00
parent 0865c99834
commit 1373b440c3
7 changed files with 149 additions and 95 deletions

View File

@@ -15,6 +15,12 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
class InvalidModelConfigException(Exception):
"""Raised when a model configuration is invalid."""
pass
class InstallStatus(str, Enum):
"""State of an install job running in the background."""

View File

@@ -27,6 +27,7 @@ from invokeai.app.services.model_install.model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
InvalidModelConfigException,
LocalModelSource,
ModelInstallJob,
ModelSource,
@@ -599,12 +600,18 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
return ModelConfigFactory.from_model_on_disk(
result = ModelConfigFactory.from_model_on_disk(
mod=model_path,
override_fields=deepcopy(fields),
hash_algo=hash_algo,
allow_unknown=self.app_config.allow_unknown_models,
)
if result.config is None:
raise InvalidModelConfigException(f"Could not identify model type for {model_path}")
return result.config
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:

View File

@@ -52,8 +52,10 @@ class Migration23Callback:
# In v6.8.0 we made some improvements to the model taxonomy and the model config schemas. There are changes
# we need to make to old configs to bring them up to date.
base = config_dict.get("base")
type = config_dict.get("type")
format = config_dict.get("format")
base = config_dict.get("base")
if base == BaseModelType.Flux.value and type == ModelType.Main.value:
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
# variants.
@@ -90,7 +92,7 @@ class Migration23Callback:
BaseModelType.StableDiffusionXL.value,
BaseModelType.StableDiffusionXLRefiner.value,
}
and type == "main"
and type == ModelType.Main.value
):
# Prior to v6.8.0, the prediction_type field was optional and would default to Epsilon if not present.
# We now make it explicit and always present. Use the existing value if present, otherwise default to
@@ -99,6 +101,20 @@ class Migration23Callback:
# It's only on SD1.x, SD2.x, and SDXL main models.
config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
if base == BaseModelType.Flux and type == ModelType.LoRA.value and format == ModelFormat.Diffusers.value:
# Prior to v6.8.0, we used the Diffusers format for FLUX LoRA models that used the diffusers _key_
# structure. This was misleading, as everywhere else in the application, we used the Diffusers format
# to indicate that the model files were in the Diffusers _file_ format (i.e. a directory containing
# the weights and config files).
#
# At runtime, we check the LoRA's state dict directly to determine the key structure, so we do not need
# to rely on the format field for this purpose. As of v6.8.0, we always use the LyCORIS format for single-
# file LoRAs, regardless of the key structure.
#
# This change allows LoRA model identification to not need a special case for FLUX LoRAs in the diffusers
# key format.
config_dict["format"] = ModelFormat.LyCORIS.value
if type == ModelType.CLIPVision.value:
# Prior to v6.8.0, some CLIP Vision models were associated with a specific base model architecture:
# - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL

View File

@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import (
Union,
@@ -241,6 +242,30 @@ additional logic in the future.
"""
@dataclass
class ModelClassificationResult:
"""Result of attempting to classify a model on disk into a specific model config.
Attributes:
config: The best matching model config, or None if no match was found.
details: A mapping of model config class names to either an instance of that class (if it matched)
or an Exception (if it didn't match or an error occurred during matching).
"""
config: AnyModelConfig | None
details: dict[str, AnyModelConfig | Exception]
@property
def all_matches(self) -> list[AnyModelConfig]:
"""Returns a list of all matching model configs found."""
return [r for r in self.details.values() if isinstance(r, Config_Base)]
@property
def match_count(self) -> int:
"""Returns the number of matching model configs found."""
return len(self.all_matches)
class ModelConfigFactory:
@staticmethod
def from_dict(fields: dict[str, Any]) -> AnyModelConfig:
@@ -311,23 +336,25 @@ class ModelConfigFactory:
path: The path to validate
Raises:
RuntimeError: If the path doesn't look like a model
ValueError: If the path doesn't look like a model
"""
if path.is_file():
# For files, just check the extension
if path.suffix.lower() not in _MODEL_EXTENSIONS:
raise RuntimeError(
raise ValueError(
f"File extension {path.suffix} is not a recognized model format. "
f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
)
else:
# For directories, do a quick file count check with early exit
total_files = 0
for item in path.rglob("*"):
# Ignore hidden files and directories
paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
for item in paths_to_check:
if item.is_file():
total_files += 1
if total_files > _MAX_FILES_IN_MODEL_DIR:
raise RuntimeError(
raise ValueError(
f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
"This looks like a general-purpose directory rather than a model. "
"Please provide a path to a specific model file or model directory."
@@ -355,22 +382,58 @@ class ModelConfigFactory:
return False
if not find_model_files(path, 0):
raise RuntimeError(
raise ValueError(
f"No model files or config files found in directory {path}. "
f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
)
@staticmethod
def matches_sort_key(m: AnyModelConfig) -> int:
"""Sort key function to prioritize model config matches in case of multiple matches."""
# It is possible that we have multiple matches. We need to prioritize them.
# Known cases where multiple matches can occur:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
# sophisticated approach.
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case _:
return 3
@staticmethod
def from_model_on_disk(
mod: str | Path | ModelOnDisk,
override_fields: dict[str, Any] | None = None,
hash_algo: HASHING_ALGORITHMS = "blake3_single",
) -> AnyModelConfig:
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
allow_unknown: bool = True,
) -> ModelClassificationResult:
"""Classify a model on disk and return the best matching model config.
Args:
mod: The model on disk to classify. Can be a path (str or Path) or a ModelOnDisk instance.
override_fields: Optional dictionary of fields to override. These fields will take precedence
over the values extracted from the model on disk, but this cannot force a match if the
model on disk doesn't actually match the config class.
hash_algo: The hashing algorithm to use when computing the model hash if needed.
Returns:
A ModelClassificationResult containing the best matching model config (or None if no match)
and a mapping of all attempted model config classes to either an instance of that class (if it matched)
or an Exception (if it didn't match or an error occurred during matching).
Raises:
ValueError: If the provided path doesn't look like a model.
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
@@ -384,93 +447,57 @@ class ModelConfigFactory:
# Store results as a mapping of config class to either an instance of that class or an exception
# that was raised when trying to build it.
results: dict[str, AnyModelConfig | Exception] = {}
details: dict[str, AnyModelConfig | Exception] = {}
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for config_class in Config_Base.CONFIG_CLASSES:
class_name = config_class.__name__
for candidate_class in filter(lambda x: x is not Unknown_Config, Config_Base.CONFIG_CLASSES):
candidate_name = candidate_class.__name__
try:
instance = config_class.from_model_on_disk(mod, fields)
# Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of
# the AnyModelConfig union.
results[class_name] = instance # type: ignore
details[candidate_name] = candidate_class.from_model_on_disk(mod, fields) # type: ignore
except NotAMatchError as e:
results[class_name] = e
logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
# This means the model didn't match this config class. It's not an error, just no match.
details[candidate_name] = e
except ValidationError as e:
# This means the model matched, but we couldn't create the pydantic model instance for the config.
# Maybe invalid overrides were provided?
results[class_name] = e
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
details[candidate_name] = e
except Exception as e:
results[class_name] = e
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
# Some other unexpected error occurred. Store the exception for reporting later.
details[candidate_name] = e
# Extract just the successful matches
# NOTE: This will include Unknown_Config matches, which we will handle later.
matches = [r for r in results.values() if isinstance(r, Config_Base)]
matches = [r for r in details.values() if isinstance(r, Config_Base)]
if not matches:
# No matches at all. This should be very rare, but just in case, we will fall back to Unknown_Config.
msg = f"No model config matched for model {mod.path}"
logger.error(msg)
raise RuntimeError(msg)
if not allow_unknown:
# No matches and we are not allowed to fall back to Unknown_Config
return ModelClassificationResult(config=None, details=details)
else:
# Fall back to Unknown_Config
# This should always succeed as Unknown_Config.from_model_on_disk never raises NotAMatch
config = Unknown_Config.from_model_on_disk(mod, fields)
details[Unknown_Config.__name__] = config
return ModelClassificationResult(config=config, details=details)
# It is possible that we have multiple matches. We need to prioritize them.
#
# Known cases where multiple matches can occur:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
#
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
# sophisticated approach.
#
# Unknown models should always be the last resort fallback.
def sort_key(m: AnyModelConfig) -> int:
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case ModelType.Unknown:
# Unknown should always be tried last as a fallback
return 999
case _:
return 3
matches.sort(key=sort_key)
# Warn if we have multiple non-unknown matches
has_unknown = any(isinstance(m, Unknown_Config) for m in matches)
real_match_count = len(matches) - (1 if has_unknown else 0)
if real_match_count > 1:
logger.warning(
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
)
instance = matches[0]
if isinstance(instance, Unknown_Config):
logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
else:
logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
matches.sort(key=ModelConfigFactory.matches_sort_key)
config = matches[0]
# Now do any post-processing needed for specific model types/bases/etc.
match instance.type:
match config.type:
case ModelType.Main:
instance.default_settings = MainModelDefaultSettings.from_base(instance.base)
config.default_settings = MainModelDefaultSettings.from_base(config.base)
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name)
config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
case ModelType.LoRA:
instance.default_settings = LoraModelDefaultSettings()
config.default_settings = LoraModelDefaultSettings()
case _:
pass
return instance
return ModelClassificationResult(config=config, details=details)
MODEL_NAME_TO_PREPROCESSOR = {

View File

@@ -5,7 +5,6 @@ from pydantic import Field
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
@@ -30,8 +29,6 @@ class Unknown_Config(Config_Base):
Note: Basic path validation (file extensions, directory structure) is already
performed by ModelConfigFactory before this method is called.
"""
if not app_config.allow_unknown_models:
raise NotAMatchError("unknown models are not allowed by configuration")
cloned_override_fields = deepcopy(override_fields)
cloned_override_fields.pop("base", None)

View File

@@ -116,15 +116,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
},
)
cls._validate_base(mod)
# Unfortunately it is difficult to distinguish SD1 and SDXL VAEs by config alone, so we may need to
# guess based on name if the config is inconclusive.
override_name = override_fields.get("name")
cls._validate_base(mod, override_name)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
def _validate_base(cls, mod: ModelOnDisk, override_name: str | None = None) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
recognized_base = cls._get_base_or_raise(mod, override_name)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@@ -134,25 +137,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
@classmethod
def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
def _name_looks_like_sdxl(cls, mod: ModelOnDisk, override_name: str | None = None) -> bool:
# Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
# we can do is guess based on name.
return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))
return bool(re.search(r"xl\b", override_name or mod.path.name, re.IGNORECASE))
@classmethod
def _guess_name(cls, mod: ModelOnDisk) -> str:
name = mod.path.name
if name == "vae":
name = mod.path.parent.name
return name
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk, override_name: str | None = None) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
if cls._config_looks_like_sdxl(config_dict):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
elif cls._name_looks_like_sdxl(mod, override_name):
return BaseModelType.StableDiffusionXL
else:
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.

View File

@@ -30,6 +30,7 @@ from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
@@ -96,15 +97,19 @@ class LoRALoader(ModelLoader):
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
if config.format is ModelFormat.OMI:
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
# is a popular choice. For example, in the diffusers training scripts:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
#
# We assume the same for LyCORIS models in diffusers key format.
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
elif config.format is ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_diffusers_format(state_dict=state_dict):
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)