refactor(mm): simplify model classification process

Previously, we had a multi-phase strategy to identify models from their
files on disk:
1. Run each model config class's `matches()` method on the files. It
checks whether the files could possibly be identified as the candidate
model type. This was intended to be a quick check. Break on the first
match.
2. If we have a match, run the config class's `parse()` method. It
derives some additional model config attrs from the model files. This
was intended to encapsulate heavier operations that may require loading
the model into memory.
3. Derive the common model config attrs, like name, description,
calculate the hash, etc. Some of these are also heavier operations.
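The three phases above can be sketched with toy stand-ins (class names and the key-set heuristics here are hypothetical; the real probe works on a `ModelOnDisk`, not a set of state dict keys):

```python
class InvalidModelConfigException(Exception):
    pass

# Toy config classes illustrating the old matches()/parse() split.
class ToyPipelineConfig:
    @classmethod
    def matches(cls, keys):  # phase 1: supposedly-quick check
        return "unet" in keys

    @classmethod
    def parse(cls, keys):  # phase 2: heavier attr derivation
        return {"type": "main"}

class ToyLoRAConfig:
    @classmethod
    def matches(cls, keys):
        return any(k.startswith("lora_te_") for k in keys)

    @classmethod
    def parse(cls, keys):
        return {"type": "lora"}

def old_classify(keys):
    # Implicit ordering: pipeline is checked before LoRA, so a pipeline
    # with merged LoRA weights resolves to "main" only by accident of order.
    for config_class in (ToyPipelineConfig, ToyLoRAConfig):
        if config_class.matches(keys):  # break on the first match
            fields = config_class.parse(keys)
            # phase 3: common fields (name, hash, ...) would be derived here
            return fields
    raise InvalidModelConfigException("no config matched")
```

Note how a model containing both `unet` and `lora_te_*` keys silently depends on the iteration order, which is one of the problems called out below.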

This strategy has some issues:
- It is not clear how the pieces fit together. There is some
back-and-forth between different methods and the config base class. It
is hard to trace the flow of logic until you fully wrap your head around
the system, which makes it difficult to add a model architecture to the
probe.
- The assumption that we could do quick, lightweight checks before
heavier checks is incorrect. We often _must_ load the model state dict
in the `matches()` method. So there is no practical perf benefit to
splitting up the responsibility of `matches()` and `parse()`.
- Sometimes we need to do the same checks in `matches()` and `parse()`.
In these cases, splitting the logic has a negative perf impact because
we are doing the same work twice.
- As we introduce the concept of an "unknown" model config (i.e. a model
that we cannot identify, but still record in the db; see #8582), we will
_always_ run _all_ the checks for every model. Therefore we need not try
to defer heavier checks or resource-intensive ops like hashing. We are
going to do them anyway.
- There are situations where a model may match multiple configs. One
known case is SD pipeline models with merged LoRAs. In the old probe
API, we relied on the implicit order of checks to know that if a model
matched for pipeline _and_ LoRA, we prefer the pipeline match. But, in
the new API, we do not have this implicit ordering of checks. To resolve
this in a resilient way, we need to get all matches up front, then use
tie-breaker logic to figure out which should win (or add "differential
diagnosis" logic to the matchers).
- Field overrides weren't handled well by this strategy. They were only
applied at the very end, if a model matched successfully. This means we
cannot tell the system "Hey, this model is type X with base Y. Trust me
bro.". We cannot override the match logic. As we move towards letting
users correct mis-identified models (see #8582), this is a requirement.

We can simplify the process significantly and better support "unknown"
models.

Firstly, model config classes now have a single `from_model_on_disk()`
method that attempts to construct an instance of the class from the
model files. This replaces the `matches()` and `parse()` methods.

If we fail to create the config instance, a special exception is raised
that indicates why we think the files cannot be identified as the given
model config class.
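The new contract can be sketched with a toy config class (names and the heuristic key are illustrative; the real method receives a `ModelOnDisk` plus the pre-built common fields):

```python
class NotAMatch(Exception):
    """Raised when a model's files cannot be identified as a given config class."""
    def __init__(self, config_class, reason):
        super().__init__(f"{config_class.__name__} does not match: {reason}")

class ToyVAEConfig:
    @classmethod
    def from_model_on_disk(cls, keys, fields):
        # Overrides short-circuit the heuristics: if the user told us the
        # type, trust them and skip the on-disk checks entirely.
        if fields.get("type") == "vae":
            return cls()
        # Heuristic (toy version): checkpoint VAEs carry encoder conv keys.
        if "encoder.conv_in" not in keys:
            raise NotAMatch(cls, "missing VAE encoder keys")
        return cls()
```

A caller either gets an instance back or a `NotAMatch` whose message says exactly which heuristic failed, which makes per-class failures easy to log and debug.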

Next, the flow for model identification is a bit simpler:
- Derive all the common fields up-front (name, desc, hash, etc).
- Merge in overrides.
- Call `from_model_on_disk()` for every config class, passing in the
fields. Overrides are handled in this method.
- Record the results for each config class and choose the best one.

The identification logic is a bit more verbose, with the special
exceptions and handling of overrides, but it is very clear what is
happening.
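End to end, the new flow amounts to something like this (a minimal, self-contained sketch with hypothetical toy classes; the real factory also derives name, hash, and the other common fields up front):

```python
class NotAMatch(Exception):
    pass

class ToyVAE:
    @classmethod
    def from_model_on_disk(cls, keys, fields):
        if "encoder.conv_in" not in keys:
            raise NotAMatch("not a VAE checkpoint")
        return cls()

class ToyLoRA:
    @classmethod
    def from_model_on_disk(cls, keys, fields):
        if not any(k.startswith("lora_te_") for k in keys):
            raise NotAMatch("no LoRA key prefixes")
        return cls()

class UnknownConfig:
    pass

def classify(keys, fields=None):
    fields = fields or {}  # common fields + overrides, built up front
    results = {}
    # Run *every* config class and record each outcome: instance or error.
    for config_class in (ToyVAE, ToyLoRA):
        try:
            results[config_class] = config_class.from_model_on_disk(keys, fields)
        except NotAMatch as e:
            results[config_class] = e
    matches = [r for r in results.values() if not isinstance(r, Exception)]
    if not matches:
        return UnknownConfig()  # fall back to the "unknown" model config
    # With multiple matches, tie-breaker logic would pick a winner here.
    return matches[0]
```

Because every class runs, a multi-match can be detected and disambiguated explicitly instead of being resolved by accidental check ordering.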

The one downside I can think of for this strategy is that we need to
check every model type instead of stopping at the first match. It's a
bit less efficient. In practice, however, this isn't a hot code path, and
the improved clarity is worth far more than perf optimizations that the
end user will likely never notice.
psychedelicious
2025-09-24 16:33:03 +10:00
parent 0fd58681a2
commit 8399de9c25
2 changed files with 663 additions and 6 deletions

@@ -612,7 +612,7 @@ class ModelInstallService(ModelInstallServiceBase):
try:
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
from dataclasses import dataclass
import json
import logging
import re
@@ -29,11 +30,19 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
from typing import (
ClassVar,
Literal,
Optional,
Self,
Type,
TypeAlias,
Union,
)
import spandrel
import torch
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
from typing_extensions import Annotated, Any, Dict
from invokeai.app.services.config.config_default import get_config
@@ -71,6 +80,18 @@ class InvalidModelConfigException(Exception):
pass
class NotAMatch(Exception):
"""Exception for when a model does not match a config class.
Args:
config_class: The config class that was being tested.
reason: The reason why the model did not match.
"""
def __init__(self, config_class: "Type[ModelConfigBase]", reason: str):
super().__init__(f"{config_class.__name__} does not match: {reason}")
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
@@ -190,8 +211,8 @@ class ModelConfigBase(ABC, BaseModel):
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
@@ -289,6 +310,13 @@ class ModelConfigBase(ABC, BaseModel):
Returns a MatchCertainty score."""
pass
@classmethod
@abstractmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
"""Constructs an instance of this config class from the model files,
or raises NotAMatch if the files cannot be identified as this config."""
pass
@staticmethod
def cast_overrides(**overrides):
"""Casts user overrides from str to Enum"""
@@ -308,7 +336,7 @@ class ModelConfigBase(ABC, BaseModel):
overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
fields = cls.parse(mod)
cls.cast_overrides(**overrides)
@@ -424,6 +452,11 @@ class T5EncoderConfigBase(ABC, BaseModel):
return {}
def load_json(path: Path) -> dict[str, Any]:
with open(path, "r") as file:
return json.load(file)
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@@ -456,6 +489,45 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.T5Encoder:
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
if format_override is not None and format_override is not ModelFormat.T5Encoder:
raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder")
if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder:
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
# Heuristic: Look for the T5EncoderModel class name in the config
try:
config = load_json(mod.path / "text_encoder_2" / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
if config_class_name != "T5EncoderModel":
raise NotAMatch(cls, "model class is not T5EncoderModel")
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
if not has_unquantized_config:
raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
return cls(**fields)
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@@ -498,10 +570,98 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.T5Encoder:
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b:
raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b")
if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b:
return cls(**fields)
# Heuristic: Look for the T5EncoderModel class name in the config
try:
config = load_json(mod.path / "text_encoder_2" / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
if config_class_name != "T5EncoderModel":
raise NotAMatch(cls, "model class is not T5EncoderModel")
# Heuristic: look for the quantization in the filename name
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
# Heuristic: Look for the presence of "SCB" suffixes in state dict keys
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
if not filename_looks_like_bnb and not has_scb_key_suffix:
raise NotAMatch(cls, "missing bnb quantization indicators")
return cls(**fields)
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.OMI:
raise NotAMatch(cls, f"format override is {format_override}, not OMI")
if type_override is ModelType.LoRA and format_override is ModelFormat.OMI:
return cls(**fields)
# Heuristic: OMI LoRAs are always files, never directories
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
# Heuristic: Look for OMI LoRA metadata
metadata = mod.metadata()
is_omi_lora_heuristic = (
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
)
if not is_omi_lora_heuristic:
raise NotAMatch(cls, "model does not match OMI LoRA heuristics")
base = fields.get("base") or cls.get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
if architecture == stable_diffusion_xl_1_lora:
return BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
return BaseModelType.Flux
else:
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_lora_override = overrides.get("type") is ModelType.LoRA
@@ -608,6 +768,54 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"base": cls.base_model(mod),
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.LyCORIS:
raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS")
if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS:
return cls(**fields)
# Heuristic: LyCORIS LoRAs are always files, never directories
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = mod.has_keys_starting_with(
{
"lora_te_",
"lora_unet_",
"lora_te1_",
"lora_te2_",
"lora_transformer_",
}
)
has_key_with_lora_suffix = mod.has_keys_ending_with(
{
"to_k_lora.up.weight",
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
}
)
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics")
return cls(**fields)
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -669,6 +877,35 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"base": cls.base_model(mod),
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers:
return cls(**fields)
# Heuristic: Diffusers LoRAs are always directories, never files
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
has_lora_weight_file = any(wf.exists() for wf in weight_files)
if not is_flux_lora_diffusers and not has_lora_weight_file:
raise NotAMatch(cls, "model does not match Diffusers LoRA heuristics")
return cls(**fields)
class VAEConfigBase(ABC, BaseModel):
type: Literal[ModelType.VAE] = ModelType.VAE
@@ -729,6 +966,43 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
raise InvalidModelConfigException("Cannot determine base type")
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
if re.search(regexp, mod.path.name, re.IGNORECASE):
return basetype
raise NotAMatch(cls, "cannot determine base type")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.VAE:
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
if format_override is not None and format_override is not ModelFormat.Checkpoint:
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint:
return cls(**fields)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
base = fields.get("base") or cls.get_base_or_raise(mod)
return cls(**fields, base=base)
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
@@ -799,6 +1073,49 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
name = mod.path.parent.name
return name
@classmethod
def get_base(cls, mod: ModelOnDisk) -> BaseModelType:
if cls._config_looks_like_sdxl(mod):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
return BaseModelType.StableDiffusionXL
else:
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
return BaseModelType.StableDiffusion1
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.VAE:
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers:
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
base = fields.get("base") or cls.get_base(mod)
return cls(**fields, base=base)
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
@@ -875,6 +1192,35 @@ class TextualInversionConfigBase(ABC, BaseModel):
raise InvalidModelConfigException(f"{p}: Could not determine base type")
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
token_dim = state_dict["emb_params"].shape[-1]
elif "clip_g" in state_dict:
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
except Exception:
pass
raise InvalidModelConfigException(f"{p}: Could not determine base type")
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
@@ -911,6 +1257,29 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.TextualInversion:
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
if format_override is not None and format_override is not ModelFormat.EmbeddingFile:
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile")
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile:
return cls(**fields)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if not cls.file_looks_like_embedding(mod):
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
base = fields.get("base") or cls.get_base_or_raise(mod)
return cls(**fields, base=base)
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
@@ -949,6 +1318,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.TextualInversion:
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
if format_override is not None and format_override is not ModelFormat.EmbeddingFolder:
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder")
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder:
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
for p in mod.weight_files():
if cls.file_looks_like_embedding(mod, p):
base = fields.get("base") or cls.get_base_or_raise(mod, p)
return cls(**fields, base=base)
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
class MainConfigBase(ABC, BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
@@ -1100,6 +1493,39 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
variant_override = fields.get("variant")
if type_override is not None and type_override is not ModelType.CLIPEmbed:
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if variant_override is not None and variant_override is not ClipVariantType.G:
raise NotAMatch(cls, f"variant override is {variant_override}, not G")
if (
type_override is ModelType.CLIPEmbed
and format_override is ModelFormat.Diffusers
and variant_override is ClipVariantType.G
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
if not is_clip_embed or clip_variant is not ClipVariantType.G:
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
return cls(**fields)
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
@@ -1130,6 +1556,39 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
variant_override = fields.get("variant")
if type_override is not None and type_override is not ModelType.CLIPEmbed:
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if variant_override is not None and variant_override is not ClipVariantType.L:
raise NotAMatch(cls, f"variant override is {variant_override}, not L")
if (
type_override is ModelType.CLIPEmbed
and format_override is ModelFormat.Diffusers
and variant_override is ClipVariantType.L
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
if not is_clip_embed or clip_variant is not ClipVariantType.L:
raise NotAMatch(cls, "model does not match CLIP-L heuristics")
return cls(**fields)
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIPVision."""
@@ -1183,6 +1642,46 @@ class SpandrelImageToImageConfig(ModelConfigBase):
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
base_override = fields.get("base")
if type_override is not None and type_override is not ModelType.SpandrelImageToImage:
raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage")
if format_override is not None and format_override is not ModelFormat.Checkpoint:
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
if base_override is not None and base_override is not BaseModelType.Any:
raise NotAMatch(cls, f"base override is {base_override}, not Any")
if (
type_override is ModelType.SpandrelImageToImage
and format_override is ModelFormat.Checkpoint
and base_override is BaseModelType.Any
):
return cls(**fields)
if not mod.path.is_file():
raise NotAMatch(cls, "model path is a directory, not a file")
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(mod.path)
base = fields.get("base") or BaseModelType.Any
return cls(**fields, base=base)
except Exception as e:
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for SigLIP."""
@@ -1202,6 +1701,8 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
base: Literal[BaseModelType.Any] = BaseModelType.Any
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
@@ -1234,6 +1735,41 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"variant": ModelVariantType.Normal,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LlavaOnevision:
raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers:
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
# Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config.json") from e
if config_class_name != "LlavaOnevisionForConditionalGeneration":
raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration")
base = fields.get("base") or BaseModelType.Any
variant = fields.get("variant") or ModelVariantType.Normal
return cls(**fields, base=base, variant=variant)
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""
@@ -1249,6 +1785,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatch(cls, "API models cannot be built from disk")
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
"""Model config for API-based video models."""
@@ -1264,6 +1803,10 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatch(cls, "API models cannot be built from disk")
def get_model_discriminator_value(v: Any) -> str:
"""
@@ -1342,6 +1885,15 @@ AnyModelConfig = Annotated[
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
@dataclass
class ModelClassificationResultSuccess:
model: AnyModelConfig
@dataclass
class ModelClassificationResultFailure:
error: Exception
ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure
class ModelConfigFactory:
@staticmethod
@@ -1352,3 +1904,108 @@ class ModelConfigFactory:
model.converted_at = timestamp
validate_hash(model.hash)
return model
@staticmethod
def build_common_fields(
mod: ModelOnDisk,
overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Builds the common fields for all model configs.
Args:
mod: The model on disk to extract fields from.
overrides: An optional dictionary of fields to override. These fields will take precedence over the values
extracted from the model on disk.
- Casts string fields to their Enum types.
- Does not validate the fields against the model config schema.
"""
_overrides: dict[str, Any] = overrides or {}
fields: dict[str, Any] = {}
if "type" in _overrides:
fields["type"] = ModelType(_overrides["type"])
if "format" in _overrides:
fields["format"] = ModelFormat(_overrides["format"])
if "base" in _overrides:
fields["base"] = BaseModelType(_overrides["base"])
if "source_type" in _overrides:
fields["source_type"] = ModelSourceType(_overrides["source_type"])
if "variant" in _overrides:
fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"])
fields["path"] = mod.path.as_posix()
fields["source"] = _overrides.get("source") or fields["path"]
fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path
fields["name"] = _overrides.get("name") or mod.name
fields["hash"] = _overrides.get("hash") or mod.hash()
fields["key"] = _overrides.get("key") or uuid_string()
fields["description"] = _overrides.get("description")
fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant()
fields["file_size"] = _overrides.get("file_size") or mod.size()
return fields
@staticmethod
def from_model_on_disk(
mod: str | Path | ModelOnDisk,
overrides: dict[str, Any] | None = None,
hash_algo: HASHING_ALGORITHMS = "blake3_single",
) -> AnyModelConfig:
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
# We will always need these fields to build any model config.
fields = ModelConfigFactory.build_common_fields(mod, overrides)
# Store results as a mapping of config class to either an instance of that class or an exception
# that was raised when trying to build it.
results: dict[type[AnyModelConfig], AnyModelConfig | Exception] = {}
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for config_class in ModelConfigBase.USING_CLASSIFY_API:
try:
instance = config_class.from_model_on_disk(mod, fields)
results[config_class] = instance
except NotAMatch as e:
results[config_class] = e
logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
except ValidationError as e:
# This means the model matched, but we couldn't create the pydantic model instance for the config.
# Maybe invalid overrides were provided?
results[config_class] = e
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
except Exception as e:
results[config_class] = e
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
if not matches and app_config.allow_unknown_models:
logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig")
return UnknownModelConfig.from_model_on_disk(mod, fields)
instance = next(iter(matches))
if len(matches) > 1:
# TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the
# matches, probably on a case-by-case basis.
#
# One known case is certain SD main (pipeline) models can look like a LoRA. This could happen if the model
# contains merged in LoRA weights.
logger.warning(
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(instance).__name__}."
)
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
return instance