mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): continued iteration on model identification
This commit is contained in:
@@ -15,6 +15,12 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Raised when a model configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from invokeai.app.services.model_install.model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
InvalidModelConfigException,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelSource,
|
||||
@@ -599,12 +600,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
hash_algo = self._app_config.hashing_algorithm
|
||||
fields = config.model_dump()
|
||||
|
||||
return ModelConfigFactory.from_model_on_disk(
|
||||
result = ModelConfigFactory.from_model_on_disk(
|
||||
mod=model_path,
|
||||
override_fields=deepcopy(fields),
|
||||
hash_algo=hash_algo,
|
||||
allow_unknown=self.app_config.allow_unknown_models,
|
||||
)
|
||||
|
||||
if result.config is None:
|
||||
raise InvalidModelConfigException(f"Could not identify model type for {model_path}")
|
||||
|
||||
return result.config
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
|
||||
@@ -52,8 +52,10 @@ class Migration23Callback:
|
||||
# In v6.8.0 we made some improvements to the model taxonomy and the model config schemas. There are a few changes
|
||||
# we need to make to old configs to bring them up to date.
|
||||
|
||||
base = config_dict.get("base")
|
||||
type = config_dict.get("type")
|
||||
format = config_dict.get("format")
|
||||
base = config_dict.get("base")
|
||||
|
||||
if base == BaseModelType.Flux.value and type == ModelType.Main.value:
|
||||
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
|
||||
# variants.
|
||||
@@ -90,7 +92,7 @@ class Migration23Callback:
|
||||
BaseModelType.StableDiffusionXL.value,
|
||||
BaseModelType.StableDiffusionXLRefiner.value,
|
||||
}
|
||||
and type == "main"
|
||||
and type == ModelType.Main.value
|
||||
):
|
||||
# Prior to v6.8.0, the prediction_type field was optional and would default to Epsilon if not present.
|
||||
# We now make it explicit and always present. Use the existing value if present, otherwise default to
|
||||
@@ -99,6 +101,20 @@ class Migration23Callback:
|
||||
# It's only on SD1.x, SD2.x, and SDXL main models.
|
||||
config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
|
||||
|
||||
if base == BaseModelType.Flux and type == ModelType.LoRA.value and format == ModelFormat.Diffusers.value:
|
||||
# Prior to v6.8.0, we used the Diffusers format for FLUX LoRA models that used the diffusers _key_
|
||||
# structure. This was misleading, as everywhere else in the application, we used the Diffusers format
|
||||
# to indicate that the model files were in the Diffusers _file_ format (i.e. a directory containing
|
||||
# the weights and config files).
|
||||
#
|
||||
# At runtime, we check the LoRA's state dict directly to determine the key structure, so we do not need
|
||||
# to rely on the format field for this purpose. As of v6.8.0, we always use the LyCORIS format for single-
|
||||
# file LoRAs, regardless of the key structure.
|
||||
#
|
||||
# This change allows LoRA model identification to not need a special case for FLUX LoRAs in the diffusers
|
||||
# key format.
|
||||
config_dict["format"] = ModelFormat.LyCORIS.value
|
||||
|
||||
if type == ModelType.CLIPVision.value:
|
||||
# Prior to v6.8.0, some CLIP Vision models were associated with a specific base model architecture:
|
||||
# - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Union,
|
||||
@@ -241,6 +242,30 @@ additional logic in the future.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelClassificationResult:
|
||||
"""Result of attempting to classify a model on disk into a specific model config.
|
||||
|
||||
Attributes:
|
||||
config: The best matching model config, or None if no match was found.
|
||||
details: A mapping of model config class names to either an instance of that class (if it matched)
|
||||
or an Exception (if it didn't match or an error occurred during matching).
|
||||
"""
|
||||
|
||||
config: AnyModelConfig | None
|
||||
details: dict[str, AnyModelConfig | Exception]
|
||||
|
||||
@property
|
||||
def all_matches(self) -> list[AnyModelConfig]:
|
||||
"""Returns a list of all matching model configs found."""
|
||||
return [r for r in self.details.values() if isinstance(r, Config_Base)]
|
||||
|
||||
@property
|
||||
def match_count(self) -> int:
|
||||
"""Returns the number of matching model configs found."""
|
||||
return len(self.all_matches)
|
||||
|
||||
|
||||
class ModelConfigFactory:
|
||||
@staticmethod
|
||||
def from_dict(fields: dict[str, Any]) -> AnyModelConfig:
|
||||
@@ -311,23 +336,25 @@ class ModelConfigFactory:
|
||||
path: The path to validate
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the path doesn't look like a model
|
||||
ValueError: If the path doesn't look like a model
|
||||
"""
|
||||
if path.is_file():
|
||||
# For files, just check the extension
|
||||
if path.suffix.lower() not in _MODEL_EXTENSIONS:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
f"File extension {path.suffix} is not a recognized model format. "
|
||||
f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
|
||||
)
|
||||
else:
|
||||
# For directories, do a quick file count check with early exit
|
||||
total_files = 0
|
||||
for item in path.rglob("*"):
|
||||
# Ignore hidden files and directories
|
||||
paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
|
||||
for item in paths_to_check:
|
||||
if item.is_file():
|
||||
total_files += 1
|
||||
if total_files > _MAX_FILES_IN_MODEL_DIR:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
|
||||
"This looks like a general-purpose directory rather than a model. "
|
||||
"Please provide a path to a specific model file or model directory."
|
||||
@@ -355,22 +382,58 @@ class ModelConfigFactory:
|
||||
return False
|
||||
|
||||
if not find_model_files(path, 0):
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
f"No model files or config files found in directory {path}. "
|
||||
f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
|
||||
f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def matches_sort_key(m: AnyModelConfig) -> int:
|
||||
"""Sort key function to prioritize model config matches in case of multiple matches."""
|
||||
|
||||
# It is possible that we have multiple matches. We need to prioritize them.
|
||||
|
||||
# Known cases where multiple matches can occur:
|
||||
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
|
||||
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
|
||||
# a config.json file. Prefer the main model.
|
||||
|
||||
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
|
||||
# sophisticated approach.
|
||||
match m.type:
|
||||
case ModelType.Main:
|
||||
return 0
|
||||
case ModelType.LoRA:
|
||||
return 1
|
||||
case ModelType.CLIPEmbed:
|
||||
return 2
|
||||
case _:
|
||||
return 3
|
||||
|
||||
@staticmethod
|
||||
def from_model_on_disk(
|
||||
mod: str | Path | ModelOnDisk,
|
||||
override_fields: dict[str, Any] | None = None,
|
||||
hash_algo: HASHING_ALGORITHMS = "blake3_single",
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Returns the best matching ModelConfig instance from a model's file/folder path.
|
||||
Raises InvalidModelConfigException if no valid configuration is found.
|
||||
Created to deprecate ModelProbe.probe
|
||||
allow_unknown: bool = True,
|
||||
) -> ModelClassificationResult:
|
||||
"""Classify a model on disk and return the best matching model config.
|
||||
|
||||
Args:
|
||||
mod: The model on disk to classify. Can be a path (str or Path) or a ModelOnDisk instance.
|
||||
override_fields: Optional dictionary of fields to override. These fields will take precedence
|
||||
over the values extracted from the model on disk, but this cannot force a match if the
|
||||
model on disk doesn't actually match the config class.
|
||||
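hash_algo: The hashing algorithm to use when computing the model hash if needed.
allow_unknown: Whether to fall back to Unknown_Config when no config class matches. If False and nothing matches, the returned result's config is None.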
|
||||
|
||||
Returns:
|
||||
A ModelClassificationResult containing the best matching model config (or None if no match)
|
||||
and a mapping of all attempted model config classes to either an instance of that class (if it matched)
|
||||
or an Exception (if it didn't match or an error occurred during matching).
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided path doesn't look like a model.
|
||||
"""
|
||||
if isinstance(mod, Path | str):
|
||||
mod = ModelOnDisk(Path(mod), hash_algo)
|
||||
@@ -384,93 +447,57 @@ class ModelConfigFactory:
|
||||
|
||||
# Store results as a mapping of config class to either an instance of that class or an exception
|
||||
# that was raised when trying to build it.
|
||||
results: dict[str, AnyModelConfig | Exception] = {}
|
||||
details: dict[str, AnyModelConfig | Exception] = {}
|
||||
|
||||
# Try to build an instance of each model config class that uses the classify API.
|
||||
# Each class will either return an instance of itself or raise NotAMatchError if it doesn't match.
|
||||
# Other exceptions may be raised if something unexpected happens during matching or building.
|
||||
for config_class in Config_Base.CONFIG_CLASSES:
|
||||
class_name = config_class.__name__
|
||||
for candidate_class in filter(lambda x: x is not Unknown_Config, Config_Base.CONFIG_CLASSES):
|
||||
candidate_name = candidate_class.__name__
|
||||
try:
|
||||
instance = config_class.from_model_on_disk(mod, fields)
|
||||
# Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of
|
||||
# the AnyModelConfig union.
|
||||
results[class_name] = instance # type: ignore
|
||||
details[candidate_name] = candidate_class.from_model_on_disk(mod, fields) # type: ignore
|
||||
except NotAMatchError as e:
|
||||
results[class_name] = e
|
||||
logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
|
||||
# This means the model didn't match this config class. It's not an error, just no match.
|
||||
details[candidate_name] = e
|
||||
except ValidationError as e:
|
||||
# This means the model matched, but we couldn't create the pydantic model instance for the config.
|
||||
# Maybe invalid overrides were provided?
|
||||
results[class_name] = e
|
||||
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
|
||||
details[candidate_name] = e
|
||||
except Exception as e:
|
||||
results[class_name] = e
|
||||
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
|
||||
# Some other unexpected error occurred. Store the exception for reporting later.
|
||||
details[candidate_name] = e
|
||||
|
||||
# Extract just the successful matches
|
||||
# NOTE: This will include Unknown_Config matches, which we will handle later.
|
||||
matches = [r for r in results.values() if isinstance(r, Config_Base)]
|
||||
matches = [r for r in details.values() if isinstance(r, Config_Base)]
|
||||
|
||||
if not matches:
|
||||
# No matches at all. This should be very rare, but just in case, we will fall back to Unknown_Config.
|
||||
msg = f"No model config matched for model {mod.path}"
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
if not allow_unknown:
|
||||
# No matches and we are not allowed to fall back to Unknown_Config
|
||||
return ModelClassificationResult(config=None, details=details)
|
||||
else:
|
||||
# Fall back to Unknown_Config
|
||||
# This should always succeed as Unknown_Config.from_model_on_disk never raises NotAMatchError
|
||||
config = Unknown_Config.from_model_on_disk(mod, fields)
|
||||
details[Unknown_Config.__name__] = config
|
||||
return ModelClassificationResult(config=config, details=details)
|
||||
|
||||
# It is possible that we have multiple matches. We need to prioritize them.
|
||||
#
|
||||
# Known cases where multiple matches can occur:
|
||||
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
|
||||
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
|
||||
# a config.json file. Prefer the main model.
|
||||
#
|
||||
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
|
||||
# sophisticated approach.
|
||||
#
|
||||
# Unknown models should always be the last resort fallback.
|
||||
def sort_key(m: AnyModelConfig) -> int:
|
||||
match m.type:
|
||||
case ModelType.Main:
|
||||
return 0
|
||||
case ModelType.LoRA:
|
||||
return 1
|
||||
case ModelType.CLIPEmbed:
|
||||
return 2
|
||||
case ModelType.Unknown:
|
||||
# Unknown should always be tried last as a fallback
|
||||
return 999
|
||||
case _:
|
||||
return 3
|
||||
|
||||
matches.sort(key=sort_key)
|
||||
|
||||
# Warn if we have multiple non-unknown matches
|
||||
has_unknown = any(isinstance(m, Unknown_Config) for m in matches)
|
||||
real_match_count = len(matches) - (1 if has_unknown else 0)
|
||||
if real_match_count > 1:
|
||||
logger.warning(
|
||||
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
|
||||
)
|
||||
|
||||
instance = matches[0]
|
||||
if isinstance(instance, Unknown_Config):
|
||||
logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
|
||||
else:
|
||||
logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
|
||||
matches.sort(key=ModelConfigFactory.matches_sort_key)
|
||||
config = matches[0]
|
||||
|
||||
# Now do any post-processing needed for specific model types/bases/etc.
|
||||
match instance.type:
|
||||
match config.type:
|
||||
case ModelType.Main:
|
||||
instance.default_settings = MainModelDefaultSettings.from_base(instance.base)
|
||||
config.default_settings = MainModelDefaultSettings.from_base(config.base)
|
||||
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
|
||||
instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name)
|
||||
config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
|
||||
case ModelType.LoRA:
|
||||
instance.default_settings = LoraModelDefaultSettings()
|
||||
config.default_settings = LoraModelDefaultSettings()
|
||||
case _:
|
||||
pass
|
||||
|
||||
return instance
|
||||
return ModelClassificationResult(config=config, details=details)
|
||||
|
||||
|
||||
MODEL_NAME_TO_PREPROCESSOR = {
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_manager.configs.base import Config_Base
|
||||
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
@@ -30,8 +29,6 @@ class Unknown_Config(Config_Base):
|
||||
Note: Basic path validation (file extensions, directory structure) is already
|
||||
performed by ModelConfigFactory before this method is called.
|
||||
"""
|
||||
if not app_config.allow_unknown_models:
|
||||
raise NotAMatchError("unknown models are not allowed by configuration")
|
||||
|
||||
cloned_override_fields = deepcopy(override_fields)
|
||||
cloned_override_fields.pop("base", None)
|
||||
|
||||
@@ -116,15 +116,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
|
||||
},
|
||||
)
|
||||
|
||||
cls._validate_base(mod)
|
||||
# Unfortunately it is difficult to distinguish SD1 and SDXL VAEs by config alone, so we may need to
|
||||
# guess based on name if the config is inconclusive.
|
||||
override_name = override_fields.get("name")
|
||||
cls._validate_base(mod, override_name)
|
||||
|
||||
return cls(**override_fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
def _validate_base(cls, mod: ModelOnDisk, override_name: str | None = None) -> None:
|
||||
"""Raise `NotAMatchError` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
recognized_base = cls._get_base_or_raise(mod, override_name)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
|
||||
|
||||
@@ -134,25 +137,18 @@ class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
@classmethod
|
||||
def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
|
||||
def _name_looks_like_sdxl(cls, mod: ModelOnDisk, override_name: str | None = None) -> bool:
|
||||
# Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
|
||||
# we can do is guess based on name.
|
||||
return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))
|
||||
return bool(re.search(r"xl\b", override_name or mod.path.name, re.IGNORECASE))
|
||||
|
||||
@classmethod
|
||||
def _guess_name(cls, mod: ModelOnDisk) -> str:
|
||||
name = mod.path.name
|
||||
if name == "vae":
|
||||
name = mod.path.parent.name
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk, override_name: str | None = None) -> BaseModelType:
|
||||
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
|
||||
if cls._config_looks_like_sdxl(config_dict):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif cls._name_looks_like_sdxl(mod):
|
||||
elif cls._name_looks_like_sdxl(mod, override_name):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
|
||||
|
||||
@@ -30,6 +30,7 @@ from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
|
||||
lora_model_from_flux_control_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
@@ -96,15 +97,19 @@ class LoRALoader(ModelLoader):
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
|
||||
if config.format is ModelFormat.OMI:
|
||||
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
|
||||
# distributed as a single file without the associated metadata containing the alpha value. We chose
|
||||
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
|
||||
# is a popular choice. For example, in the diffusers training scripts:
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
|
||||
#
|
||||
# We assume the same for LyCORIS models in diffusers key format.
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
elif config.format is ModelFormat.LyCORIS:
|
||||
if is_state_dict_likely_in_flux_diffusers_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
|
||||
|
||||
Reference in New Issue
Block a user