From e278b120b1c151d83c3a9fb63054edc59b72d1cd Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 24 Sep 2025 16:33:03 +1000
Subject: [PATCH] refactor(mm): simplify model classification process

Previously, we had a multi-phase strategy to identify models from their files on disk:

1. Run each model config class's `matches()` method on the files. It checks whether the model could possibly be identified as the candidate model type. This was intended to be a quick check. Break on the first match.
2. If we have a match, run the config class's `parse()` method. It derives some additional model config attrs from the model files. This was intended to encapsulate heavier operations that may require loading the model into memory.
3. Derive the common model config attrs, like name, description, hash, etc. Some of these are also heavier operations.

This strategy has some issues:

- It is not clear how the pieces fit together. There is some back-and-forth between different methods and the config base class. It is hard to trace the flow of logic until you fully wrap your head around the system, and it is therefore difficult to add a model architecture to the probe.
- The assumption that we could do quick, lightweight checks before heavier checks is incorrect. We often _must_ load the model state dict in the `matches()` method. So there is no practical perf benefit to splitting up the responsibility of `matches()` and `parse()`.
- Sometimes we need to do the same checks in `matches()` and `parse()`. In these cases, splitting the logic has a negative perf impact because we are doing the same work twice.
- As we introduce the concept of an "unknown" model config (i.e. a model that we cannot identify, but still record in the db; see #8582), we will _always_ run _all_ the checks for every model. Therefore we need not try to defer heavier checks or resource-intensive ops like hashing. We are going to do them anyway.
- There are situations where a model may match multiple configs. One known case is SD pipeline models with merged LoRAs. In the old probe API, we relied on the implicit order of checks to know that if a model matched for pipeline _and_ LoRA, we prefer the pipeline match. But, in the new API, we do not have this implicit ordering of checks. To resolve this in a resilient way, we need to get all matches up front, then use tie-breaker logic to figure out which should win (or add "differential diagnosis" logic to the matchers).
- Field overrides weren't handled well by this strategy. They were only applied at the very end, if a model matched successfully. This means we cannot tell the system "Hey, this model is type X with base Y. Trust me bro." We cannot override the match logic. As we move towards letting users correct mis-identified models (see #8582), this is a requirement.

We can simplify the process significantly and better support "unknown" models.

Firstly, model config classes now have a single `from_model_on_disk()` method that attempts to construct an instance of the class from the model files. This replaces the `matches()` and `parse()` methods. If we fail to create the config instance, a special exception is raised that indicates why we think the files cannot be identified as the given model config class.

Next, the flow for model identification is a bit simpler:

- Derive all the common fields up-front (name, desc, hash, etc).
- Merge in overrides.
- Call `from_model_on_disk()` for every config class, passing in the fields. Overrides are handled in this method.
- Record the results for each config class and choose the best one (sketched just below).
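In pseudocode, the new flow looks roughly like this (a simplified sketch of the `ModelConfigFactory.from_model_on_disk` implementation in this diff; logging and tie-breaking elided):

    fields = build_common_fields(mod, overrides)  # name, hash, path, etc., with overrides merged in
    results = {}
    for config_class in config_classes:
        try:
            results[config_class] = config_class.from_model_on_disk(mod, fields)
        except NotAMatch as e:
            results[config_class] = e
    matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
    # pick the best match, or fall back to an unknown-model config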
The identification logic is a bit more verbose, with the special exceptions and handling of overrides, but it is very clear what is happening. The one downside I can think of for this strategy is that we need to check every model type, instead of stopping at the first match. It's a bit less efficient. In practice, however, this isn't a hot code path, and the improved clarity is worth far more than perf optimizations that the end user will likely never notice.

--- .../model_install/model_install_default.py | 2 +- invokeai/backend/model_manager/config.py | 667 +++++++++++++++++- 2 files changed, 663 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 5bc9af8e6b..cae423afe3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -612,7 +612,7 @@ class ModelInstallService(ModelInstallServiceBase): try: return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore except InvalidModelConfigException: - return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields) + return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo) def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9ab1cdcc0f..a5c2058e1b 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error. """ # pyright: reportIncompatibleVariableOverride=false +from dataclasses import dataclass import json import logging import re @@ -29,11 +30,19 @@ from abc import ABC, abstractmethod from enum import Enum from inspect import isabstract from pathlib import Path -from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union +from typing import ( ClassVar, Literal, Optional, Self, Type, TypeAlias, Union, ) import spandrel import torch -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config @@ -71,6 +80,18 @@ class InvalidModelConfigException(Exception): pass +class NotAMatch(Exception): + """Exception for when a model does not match a config class. + + Args: + config_class: The config class that was being tested. + reason: The reason why the model did not match.
+ """ + + def __init__(self, config_class: "Type[ModelConfigBase]", reason: str): + super().__init__(f"{config_class.__name__} does not match: {reason}") + + DEFAULTS_PRECISION = Literal["fp16", "fp32"] @@ -190,8 +211,8 @@ class ModelConfigBase(ABC, BaseModel): ) usage_info: Optional[str] = Field(default=None, description="Usage information for this model") - USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set() - USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set() + USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set() + USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set() _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED def __init_subclass__(cls, **kwargs): @@ -289,6 +310,13 @@ class ModelConfigBase(ABC, BaseModel): Returns a MatchCertainty score.""" pass + @classmethod + @abstractmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + """Performs a quick check to determine if the config matches the model. + Returns a MatchCertainty score.""" + pass + @staticmethod def cast_overrides(**overrides): """Casts user overrides from str to Enum""" @@ -308,7 +336,7 @@ class ModelConfigBase(ABC, BaseModel): overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"]) @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, **overrides): + def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides): """Creates an instance of this config or raises InvalidModelConfigException.""" fields = cls.parse(mod) cls.cast_overrides(**overrides) @@ -424,6 +452,11 @@ class T5EncoderConfigBase(ABC, BaseModel): return {} +def load_json(path: Path) -> dict[str, Any]: + with open(path, "r") as file: + return json.load(file) + + class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder @@ -456,6 +489,45 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): return MatchCertainty.NEVER + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.T5Encoder: + raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder") + + if format_override is not None and format_override is not ModelFormat.T5Encoder: + raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder") + + if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder: + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + # Heuristic: Look for the T5EncoderModel class name in the config + try: + config = load_json(mod.path / "text_encoder_2" / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e + + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e + + if config_class_name != "T5EncoderModel": + raise NotAMatch(cls, "model class is not T5EncoderModel") + + # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models) + has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() + + if not has_unquantized_config: + raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") + + return 
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b @@ -498,10 +570,98 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): return MatchCertainty.NEVER + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.T5Encoder: + raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder") + + if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b: + raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b") + + if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b: + return cls(**fields) + + # Heuristic: Look for the T5EncoderModel class name in the config + try: + config = load_json(mod.path / "text_encoder_2" / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e + + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e + + if config_class_name != "T5EncoderModel": + raise NotAMatch(cls, "model class is not T5EncoderModel") + + # Heuristic: look for the quantization in the file name + filename_looks_like_bnb = any("llm_int8" in f.as_posix() for f in mod.weight_files()) + + # Heuristic: Look for the presence of "SCB" suffixes in state dict keys + has_scb_key_suffix = mod.has_keys_ending_with("SCB") + + if not filename_looks_like_bnb and not has_scb_key_suffix: + raise NotAMatch(cls, "missing bnb quantization indicators") + + return cls(**fields) + class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.OMI] = ModelFormat.OMI + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.LoRA: + raise NotAMatch(cls, f"type override is {type_override}, not LoRA") + + if format_override is not None and format_override is not ModelFormat.OMI: + raise NotAMatch(cls, f"format override is {format_override}, not OMI") + + if type_override is ModelType.LoRA and format_override is ModelFormat.OMI: + return cls(**fields) + + # Heuristic: OMI LoRAs are always files, never directories + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + # Heuristic: differential diagnosis vs ControlLoRA and Diffusers + if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA") + + # Heuristic: Look for OMI LoRA metadata + metadata = mod.metadata() + is_omi_lora_heuristic = ( + bool(metadata.get("modelspec.sai_model_spec")) + and metadata.get("ot_branch") == "omi_format" + and metadata.get("modelspec.architecture", "").split("/")[-1].lower() == "lora" + ) + + if not is_omi_lora_heuristic: + raise NotAMatch(cls, "model does not match OMI LoRA heuristics") + + fields["base"] = fields.get("base") or cls.get_base_or_raise(mod) + + return cls(**fields) + + @classmethod + def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + metadata = 
mod.metadata() + architecture = metadata["modelspec.architecture"] + + if architecture == stable_diffusion_xl_1_lora: + return BaseModelType.StableDiffusionXL + elif architecture == flux_dev_1_lora: + return BaseModelType.Flux + else: + raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") + @classmethod def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: is_lora_override = overrides.get("type") is ModelType.LoRA @@ -608,6 +768,54 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): "base": cls.base_model(mod), } + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.LoRA: + raise NotAMatch(cls, f"type override is {type_override}, not LoRA") + + if format_override is not None and format_override is not ModelFormat.LyCORIS: + raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS") + + if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS: + return cls(**fields) + + # Heuristic: LyCORIS LoRAs are always files, never directories + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + # Heuristic: differential diagnosis vs ControlLoRA and Diffusers + if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA") + + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. + has_key_with_lora_prefix = mod.has_keys_starting_with( + { + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + } + ) + + has_key_with_lora_suffix = mod.has_keys_ending_with( + { + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + } + ) + + if not has_key_with_lora_prefix and not has_key_with_lora_suffix: + raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") + + return cls(**fields) + class ControlAdapterConfigBase(ABC, BaseModel): default_settings: Optional[ControlAdapterDefaultSettings] = Field( @@ -669,6 +877,35 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): "base": cls.base_model(mod), } + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.LoRA: + raise NotAMatch(cls, f"type override is {type_override}, not LoRA") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") + + if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers: + return cls(**fields) + + # Heuristic: Diffusers LoRAs are always directories, never files + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers + + suffixes = ["bin", "safetensors"] + weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] + has_lora_weight_file = any(wf.exists() for wf in weight_files) + + if not is_flux_lora_diffusers and not has_lora_weight_file: + raise NotAMatch(cls, "model does not match 
Diffusers LoRA heuristics") + + return cls(**fields) + class VAEConfigBase(ABC, BaseModel): type: Literal[ModelType.VAE] = ModelType.VAE @@ -729,6 +966,43 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): raise InvalidModelConfigException("Cannot determine base type") + @classmethod + def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name + for regexp, basetype in [ + (r"xl", BaseModelType.StableDiffusionXL), + (r"sd2", BaseModelType.StableDiffusion2), + (r"vae", BaseModelType.StableDiffusion1), + (r"FLUX.1-schnell_ae", BaseModelType.Flux), + ]: + if re.search(regexp, mod.path.name, re.IGNORECASE): + return basetype + + raise NotAMatch(cls, "cannot determine base type") + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.VAE: + raise NotAMatch(cls, f"type override is {type_override}, not VAE") + + if format_override is not None and format_override is not ModelFormat.Checkpoint: + raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint") + + if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint: + return cls(**fields) + + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): + raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") + + fields["base"] = fields.get("base") or cls.get_base_or_raise(mod) + return cls(**fields) + class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" @@ -799,6 +1073,49 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): name = mod.path.parent.name return name + @classmethod + def get_base(cls, mod: ModelOnDisk) -> BaseModelType: + if cls._config_looks_like_sdxl(mod): + return BaseModelType.StableDiffusionXL + elif cls._name_looks_like_sdxl(mod): + return BaseModelType.StableDiffusionXL + else: + # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
+ return BaseModelType.StableDiffusion1 + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.VAE: + raise NotAMatch(cls, f"type override is {type_override}, not VAE") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") + + if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers: + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + try: + config = load_json(mod.path / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load config.json") from e + + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from config") from e + + if config_class_name not in cls.CLASS_NAMES: + raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") + + fields["base"] = fields.get("base") or cls.get_base(mod) + return cls(**fields) + class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for ControlNet models (diffusers version).""" @@ -875,6 +1192,35 @@ class TextualInversionConfigBase(ABC, BaseModel): raise InvalidModelConfigException(f"{p}: Could not determine base type") + @classmethod + def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: + p = path or mod.path + + try: + state_dict = mod.load_state_dict(p) + if "string_to_token" in state_dict: + token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] + elif "emb_params" in state_dict: + token_dim = state_dict["emb_params"].shape[-1] + elif "clip_g" in state_dict: + token_dim = state_dict["clip_g"].shape[-1] + else: + token_dim = list(state_dict.values())[0].shape[0] + + match token_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL + case _: + pass + except Exception: + pass + + raise NotAMatch(cls, f"{p}: could not determine base type") + class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): """Model config for textual inversion embeddings.""" @@ -911,6 +1257,29 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.TextualInversion: + raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion") + + if format_override is not None and format_override is not ModelFormat.EmbeddingFile: + raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile") + + if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile: + return cls(**fields) + + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + if not cls.file_looks_like_embedding(mod): + raise NotAMatch(cls, "model does not look like a textual inversion embedding file") + + fields["base"] = fields.get("base") or cls.get_base_or_raise(mod) + return cls(**fields)
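As a concrete illustration of the token-dimension heuristic above (fabricated state dict, illustrative only):

    import torch

    # A fabricated SD1-style embedding: a token dim of 768 maps to BaseModelType.StableDiffusion1.
    state_dict = {"emb_params": torch.zeros(4, 768)}
    assert state_dict["emb_params"].shape[-1] == 768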
fields.get("base") or cls.get_base_or_raise(mod) + return cls(**fields, base=base) + class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): """Model config for textual inversion embeddings.""" @@ -949,6 +1318,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.TextualInversion: + raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion") + + if format_override is not None and format_override is not ModelFormat.EmbeddingFolder: + raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder") + + if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder: + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + for p in mod.weight_files(): + if cls.file_looks_like_embedding(mod, p): + base = fields.get("base") or cls.get_base_or_raise(mod, p) + return cls(**fields, base=base) + + raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") + class MainConfigBase(ABC, BaseModel): type: Literal[ModelType.Main] = ModelType.Main @@ -1100,6 +1493,39 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): return MatchCertainty.NEVER + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + variant_override = fields.get("variant") + + if type_override is not None and type_override is not ModelType.CLIPEmbed: + raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") + + if variant_override is not None and variant_override is not ClipVariantType.G: + raise NotAMatch(cls, f"variant override is {variant_override}, not G") + + if ( + type_override is ModelType.CLIPEmbed + and format_override is ModelFormat.Diffusers + and variant_override is ClipVariantType.G + ): + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + is_clip_embed = cls.is_clip_text_encoder(mod) + clip_variant = cls.get_clip_variant_type(mod) + + if not is_clip_embed or clip_variant is not ClipVariantType.G: + raise NotAMatch(cls, "model does not match CLIP-G heuristics") + + return cls(**fields) + class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): """Model config for CLIP-L Embeddings.""" @@ -1130,6 +1556,39 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): return MatchCertainty.NEVER + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + variant_override = fields.get("variant") + + if type_override is not None and type_override is not ModelType.CLIPEmbed: + raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format 
override is {format_override}, not Diffusers") + + if variant_override is not None and variant_override is not ClipVariantType.L: + raise NotAMatch(cls, f"variant override is {variant_override}, not L") + + if ( + type_override is ModelType.CLIPEmbed + and format_override is ModelFormat.Diffusers + and variant_override is ClipVariantType.L + ): + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + is_clip_embed = cls.is_clip_text_encoder(mod) + clip_variant = cls.get_clip_variant_type(mod) + + if not is_clip_embed or clip_variant is not ClipVariantType.L: + raise NotAMatch(cls, "model does not match CLIP-L heuristics") + + return cls(**fields) + class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for CLIPVision.""" @@ -1183,6 +1642,46 @@ class SpandrelImageToImageConfig(ModelConfigBase): def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: return {} + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + base_override = fields.get("base") + + if type_override is not None and type_override is not ModelType.SpandrelImageToImage: + raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage") + + if format_override is not None and format_override is not ModelFormat.Checkpoint: + raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint") + + if base_override is not None and base_override is not BaseModelType.Any: + raise NotAMatch(cls, f"base override is {base_override}, not Any") + + if ( + type_override is ModelType.SpandrelImageToImage + and format_override is ModelFormat.Checkpoint + and base_override is BaseModelType.Any + ): + return cls(**fields) + + if not mod.path.is_file(): + raise NotAMatch(cls, "model path is a directory, not a file") + + try: + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were + # explored to avoid this: + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta + # device. Unfortunately, some Spandrel models perform operations during initialization that are not + # supported on meta tensors. + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. + # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to + # maintain it, and the risk of false positive detections is higher. 
+ SpandrelImageToImageModel.load_from_file(mod.path) + except Exception as e: + raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e + + fields["base"] = fields.get("base") or BaseModelType.Any + return cls(**fields) + class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for SigLIP.""" @@ -1202,6 +1701,8 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): """Model config for Llava Onevision models.""" type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision + base: Literal[BaseModelType.Any] = BaseModelType.Any + variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal @classmethod def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: @@ -1234,6 +1735,41 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): "variant": ModelVariantType.Normal, } + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.LlavaOnevision: + raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") + + if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers: + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + # Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config + try: + config = load_json(mod.path / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load config.json") from e + + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from config.json") from e + + if config_class_name != "LlavaOnevisionForConditionalGeneration": + raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration") + + fields["base"] = fields.get("base") or BaseModelType.Any + fields["variant"] = fields.get("variant") or ModelVariantType.Normal + return cls(**fields) + class ApiModelConfig(MainConfigBase, ModelConfigBase): """Model config for API-based models.""" @@ -1249,6 +1785,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase): def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: raise NotImplementedError("API models are not parsed from disk.") + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + raise NotAMatch(cls, "API models cannot be built from disk") class VideoApiModelConfig(VideoConfigBase, ModelConfigBase): """Model config for API-based video models.""" @@ -1264,6 +1803,10 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase): def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: raise NotImplementedError("API models are not parsed from disk.") + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + raise NotAMatch(cls, "API models cannot be built from disk") + def get_model_discriminator_value(v: Any) -> str: """ @@ -1342,6 +1885,15 @@ AnyModelConfig = Annotated[ AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
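For reference, the discriminated union above means stored config dicts can be rehydrated into concrete config classes via the TypeAdapter. An illustrative sketch (the payload and file name are fabricated; a real record carries every required field):

    # Illustrative: deserialize a persisted record into a concrete config class.
    raw = load_json(Path("model_record.json"))  # e.g. {"type": "vae", "format": "checkpoint", ...}
    config = AnyModelConfigValidator.validate_python(raw)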
+@dataclass +class ModelClassificationResultSuccess: + model: AnyModelConfig + +@dataclass +class ModelClassificationResultFailure: + error: Exception + +ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure class ModelConfigFactory: @staticmethod @@ -1352,3 +1904,108 @@ class ModelConfigFactory: model.converted_at = timestamp validate_hash(model.hash) return model + + @staticmethod + def build_common_fields( + mod: ModelOnDisk, + overrides: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Builds the common fields for all model configs. + + Args: + mod: The model on disk to extract fields from. + overrides: An optional dictionary of fields to override. These fields will take precedence over the values + extracted from the model on disk. + + - Casts string fields to their Enum types. + - Does not validate the fields against the model config schema. + """ + + _overrides: dict[str, Any] = overrides or {} + fields: dict[str, Any] = {} + + if "type" in _overrides: + fields["type"] = ModelType(_overrides["type"]) + + if "format" in _overrides: + fields["format"] = ModelFormat(_overrides["format"]) + + if "base" in _overrides: + fields["base"] = BaseModelType(_overrides["base"]) + + if "source_type" in _overrides: + fields["source_type"] = ModelSourceType(_overrides["source_type"]) + + if "variant" in _overrides: + fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"]) + + fields["path"] = mod.path.as_posix() + fields["source"] = _overrides.get("source") or fields["path"] + fields["source_type"] = fields.get("source_type") or ModelSourceType.Path + fields["name"] = _overrides.get("name") or mod.name + fields["hash"] = _overrides.get("hash") or mod.hash() + fields["key"] = _overrides.get("key") or uuid_string() + fields["description"] = _overrides.get("description") + fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant() + fields["file_size"] = _overrides.get("file_size") or mod.size() + + return fields
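An illustrative example of how string overrides flow through `build_common_fields` (the file name is fabricated, and the enum string values are assumed to be the usual `"lora"`/`"sdxl"`):

    # Illustrative only: string overrides are cast to their enum types up front.
    mod = ModelOnDisk(Path("my-lora.safetensors"), "blake3_single")
    fields = ModelConfigFactory.build_common_fields(
        mod,
        overrides={"type": "lora", "base": "sdxl", "name": "my-lora"},
    )
    # fields["type"] == ModelType.LoRA, fields["base"] == BaseModelType.StableDiffusionXL,
    # and the common attrs (path, hash, key, file_size, ...) are derived from the model on disk.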
+ @staticmethod + def from_model_on_disk( + mod: str | Path | ModelOnDisk, + overrides: dict[str, Any] | None = None, + hash_algo: HASHING_ALGORITHMS = "blake3_single", + ) -> AnyModelConfig: + """ + Returns the best matching ModelConfig instance from a model's file/folder path. + Raises InvalidModelConfigException if no valid configuration is found. + Created to deprecate ModelProbe.probe + """ + if isinstance(mod, Path | str): + mod = ModelOnDisk(Path(mod), hash_algo) + + # We will always need these fields to build any model config. + fields = ModelConfigFactory.build_common_fields(mod, overrides) + + # Store results as a mapping of config class to either an instance of that class or an exception + # that was raised when trying to build it. + results: dict[type[AnyModelConfig], AnyModelConfig | Exception] = {} + + # Try to build an instance of each model config class that uses the classify API. + # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. + # Other exceptions may be raised if something unexpected happens during matching or building. + for config_class in ModelConfigBase.USING_CLASSIFY_API: + try: + instance = config_class.from_model_on_disk(mod, fields) + results[config_class] = instance + except NotAMatch as e: + results[config_class] = e + logger.debug(f"No match for {config_class.__name__} on model {mod.name}") + except ValidationError as e: + # This means the model matched, but we couldn't create the pydantic model instance for the config. + # Maybe invalid overrides were provided? + results[config_class] = e + logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}") + except Exception as e: + results[config_class] = e + logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}") + + matches = [r for r in results.values() if isinstance(r, ModelConfigBase)] + + if not matches: + if app_config.allow_unknown_models: + logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig") + return UnknownModelConfig.from_model_on_disk(mod, fields) + raise InvalidModelConfigException(f"Unable to identify model {mod.name}") + + instance = matches[0] + if len(matches) > 1: + # TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the + # matches, probably on a case-by-case basis. + # + # One known case is that certain SD main (pipeline) models can look like a LoRA. This could happen if the + # model contains merged in LoRA weights. + logger.warning( + f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(instance).__name__}." + ) + logger.info(f"Model {mod.name} classified as {type(instance).__name__}") + return instance
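With this in place, adding a new architecture to the classifier is a single method on the new config class. A hypothetical sketch (the class, state dict keys, and heuristics are illustrative, not part of this diff):

    # Hypothetical example of extending the classify API with a new architecture.
    class MyArchCheckpointConfig(ModelConfigBase):
        type: Literal[ModelType.Main] = ModelType.Main
        format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

        @classmethod
        def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
            if fields.get("type") not in (None, ModelType.Main):
                raise NotAMatch(cls, f"type override is {fields.get('type')}, not Main")
            if fields.get("format") not in (None, ModelFormat.Checkpoint):
                raise NotAMatch(cls, f"format override is {fields.get('format')}, not Checkpoint")
            # Heuristic: look for state dict keys unique to this architecture (illustrative).
            if not mod.has_keys_starting_with({"my_arch."}):
                raise NotAMatch(cls, "model does not match MyArch heuristics")
            return cls(**fields)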