diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 5bc9af8e6b..cae423afe3 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -612,7 +612,7 @@ class ModelInstallService(ModelInstallServiceBase):
         try:
             return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo)  # type: ignore
         except InvalidModelConfigException:
-            return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
+            return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo)
 
     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 9ab1cdcc0f..a5c2058e1b 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
 """
 
 # pyright: reportIncompatibleVariableOverride=false
+from dataclasses import dataclass
 import json
 import logging
 import re
@@ -29,11 +30,19 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import isabstract
 from pathlib import Path
-from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
+from typing import (
+    ClassVar,
+    Literal,
+    Optional,
+    Self,
+    Type,
+    TypeAlias,
+    Union,
+)
 
 import spandrel
 import torch
-from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
 from typing_extensions import Annotated, Any, Dict
 
 from invokeai.app.services.config.config_default import get_config
@@ -71,6 +80,18 @@ class InvalidModelConfigException(Exception):
     pass
 
 
+class NotAMatch(Exception):
+    """Exception for when a model does not match a config class.
+
+    Args:
+        config_class: The config class that was being tested.
+        reason: The reason why the model did not match.
+    """
+
+    def __init__(self, config_class: "Type[ModelConfigBase]", reason: str):
+        super().__init__(f"{config_class.__name__} does not match: {reason}")
+
+
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]
 
@@ -190,8 +211,8 @@ class ModelConfigBase(ABC, BaseModel):
     )
     usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
 
-    USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
-    USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
+    USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
+    USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
     _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
 
     def __init_subclass__(cls, **kwargs):
@@ -289,6 +310,13 @@ class ModelConfigBase(ABC, BaseModel):
         Returns a MatchCertainty score."""
         pass
 
+    @classmethod
+    @abstractmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        """Creates an instance of this config if the model on disk matches it,
+        or raises NotAMatch."""
+        pass
+
     @staticmethod
     def cast_overrides(**overrides):
         """Casts user overrides from str to Enum"""
@@ -308,7 +336,7 @@ class ModelConfigBase(ABC, BaseModel):
             overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
 
     @classmethod
-    def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
+    def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides):
         """Creates an instance of this config or raises InvalidModelConfigException."""
         fields = cls.parse(mod)
         cls.cast_overrides(**overrides)
@@ -424,6 +452,11 @@ class T5EncoderConfigBase(ABC, BaseModel):
         return {}
 
 
+def load_json(path: Path) -> dict[str, Any]:
+    with open(path, "r") as file:
+        return json.load(file)
+
+
 class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
 
@@ -456,6 +489,45 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
 
         return MatchCertainty.NEVER
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.T5Encoder:
+            raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
+
+        if format_override is not None and format_override is not ModelFormat.T5Encoder:
+            raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder")
+
+        if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder:
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        # Heuristic: Look for the T5EncoderModel class name in the config
+        try:
+            config = load_json(mod.path / "text_encoder_2" / "config.json")
+        except Exception as e:
+            raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
+
+        try:
+            config_class_name = get_class_name_from_config(config)
+        except Exception as e:
+            raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
+
+        if config_class_name != "T5EncoderModel":
+            raise NotAMatch(cls, "model class is not T5EncoderModel")
+
+        # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
+        has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
+
+        if not has_unquantized_config:
+            raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
+
+        return cls(**fields)
+
 
 class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
 
@@ -498,10 +570,98 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
 
         return MatchCertainty.NEVER
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.T5Encoder:
+            raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
+
+        if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b:
+            raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b")
+
+        if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b:
+            return cls(**fields)
+
+        # Heuristic: Look for the T5EncoderModel class name in the config
+        try:
+            config = load_json(mod.path / "text_encoder_2" / "config.json")
+        except Exception as e:
+            raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
+
+        try:
+            config_class_name = get_class_name_from_config(config)
+        except Exception as e:
+            raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
+
+        if config_class_name != "T5EncoderModel":
+            raise NotAMatch(cls, "model class is not T5EncoderModel")
+
+        # Heuristic: look for the quantization in the filename
+        filename_looks_like_bnb = any("llm_int8" in f.as_posix() for f in mod.weight_files())
+
+        # Heuristic: Look for the presence of "SCB" suffixes in state dict keys
+        has_scb_key_suffix = mod.has_keys_ending_with("SCB")
+
+        if not filename_looks_like_bnb and not has_scb_key_suffix:
+            raise NotAMatch(cls, "missing bnb quantization indicators")
+
+        return cls(**fields)
+
 
 class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.OMI] = ModelFormat.OMI
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.LoRA:
+            raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
+
+        if format_override is not None and format_override is not ModelFormat.OMI:
+            raise NotAMatch(cls, f"format override is {format_override}, not OMI")
+
+        if type_override is ModelType.LoRA and format_override is ModelFormat.OMI:
+            return cls(**fields)
+
+        # Heuristic: OMI LoRAs are always files, never directories
+        if mod.path.is_dir():
+            raise NotAMatch(cls, "model path is a directory, not a file")
+
+        # Heuristic: differential diagnosis vs ControlLoRA and Diffusers
+        if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
+            raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
+
+        # Heuristic: Look for OMI LoRA metadata
+        metadata = mod.metadata()
+        is_omi_lora_heuristic = (
+            bool(metadata.get("modelspec.sai_model_spec"))
+            and metadata.get("ot_branch") == "omi_format"
+            and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
+        )
+
+        if not is_omi_lora_heuristic:
+            raise NotAMatch(cls, "model does not match OMI LoRA heuristics")
+
+        fields["base"] = fields.get("base") or cls.get_base_or_raise(mod)
+
+        return cls(**fields)
+
+    @classmethod
+    def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
+        metadata = mod.metadata()
+        architecture = metadata["modelspec.architecture"]
+
+        if architecture == stable_diffusion_xl_1_lora:
+            return BaseModelType.StableDiffusionXL
+        elif architecture == flux_dev_1_lora:
+            return BaseModelType.Flux
+        else:
+            raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
+
     @classmethod
     def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
         is_lora_override = overrides.get("type") is ModelType.LoRA
@@ -608,6 +768,54 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
             "base": cls.base_model(mod),
         }
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.LoRA:
+            raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
+
+        if format_override is not None and format_override is not ModelFormat.LyCORIS:
+            raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS")
{format_override}, not LyCORIS") + + if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS: + return cls(**fields) + + # Heuristic: LyCORIS LoRAs are always files, never directories + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + # Heuristic: differential diagnosis vs ControlLoRA and Diffusers + if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA") + + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. + has_key_with_lora_prefix = mod.has_keys_starting_with( + { + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + } + ) + + has_key_with_lora_suffix = mod.has_keys_ending_with( + { + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + } + ) + + if not has_key_with_lora_prefix and not has_key_with_lora_suffix: + raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") + + return cls(**fields) + class ControlAdapterConfigBase(ABC, BaseModel): default_settings: Optional[ControlAdapterDefaultSettings] = Field( @@ -669,6 +877,35 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): "base": cls.base_model(mod), } + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not None and type_override is not ModelType.LoRA: + raise NotAMatch(cls, f"type override is {type_override}, not LoRA") + + if format_override is not None and format_override is not ModelFormat.Diffusers: + raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") + + if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers: + return cls(**fields) + + # Heuristic: Diffusers LoRAs are always directories, never files + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers + + suffixes = ["bin", "safetensors"] + weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] + has_lora_weight_file = any(wf.exists() for wf in weight_files) + + if not is_flux_lora_diffusers and not has_lora_weight_file: + raise NotAMatch(cls, "model does not match Diffusers LoRA heuristics") + + return cls(**fields) + class VAEConfigBase(ABC, BaseModel): type: Literal[ModelType.VAE] = ModelType.VAE @@ -729,6 +966,43 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): raise InvalidModelConfigException("Cannot determine base type") + @classmethod + def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name + for regexp, basetype in [ + (r"xl", BaseModelType.StableDiffusionXL), + (r"sd2", BaseModelType.StableDiffusion2), + (r"vae", BaseModelType.StableDiffusion1), + (r"FLUX.1-schnell_ae", BaseModelType.Flux), + ]: + if re.search(regexp, mod.path.name, re.IGNORECASE): + return basetype + + raise NotAMatch(cls, "cannot determine base type") + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + type_override = fields.get("type") + format_override = fields.get("format") + + if type_override is not 
+            raise NotAMatch(cls, f"type override is {type_override}, not VAE")
+
+        if format_override is not None and format_override is not ModelFormat.Checkpoint:
+            raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
+
+        if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint:
+            return cls(**fields)
+
+        if mod.path.is_dir():
+            raise NotAMatch(cls, "model path is a directory, not a file")
+
+        if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
+            raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
+
+        fields["base"] = fields.get("base") or cls.get_base_or_raise(mod)
+        return cls(**fields)
+
 
 class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
     """Model config for standalone VAE models (diffusers version)."""
 
@@ -799,6 +1073,49 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
             name = mod.path.parent.name
         return name
 
+    @classmethod
+    def get_base(cls, mod: ModelOnDisk) -> BaseModelType:
+        if cls._config_looks_like_sdxl(mod):
+            return BaseModelType.StableDiffusionXL
+        elif cls._name_looks_like_sdxl(mod):
+            return BaseModelType.StableDiffusionXL
+        else:
+            # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
+            return BaseModelType.StableDiffusion1
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.VAE:
+            raise NotAMatch(cls, f"type override is {type_override}, not VAE")
+
+        if format_override is not None and format_override is not ModelFormat.Diffusers:
+            raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
+
+        if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers:
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        try:
+            config = load_json(mod.path / "config.json")
+        except Exception as e:
+            raise NotAMatch(cls, "unable to load config.json") from e
+
+        try:
+            config_class_name = get_class_name_from_config(config)
+        except Exception as e:
+            raise NotAMatch(cls, "unable to determine class name from config") from e
+
+        if config_class_name not in cls.CLASS_NAMES:
+            raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
+
+        fields["base"] = fields.get("base") or cls.get_base(mod)
+        return cls(**fields)
+
 
 class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for ControlNet models (diffusers version)."""
 
@@ -875,6 +1192,35 @@ class TextualInversionConfigBase(ABC, BaseModel):
 
         raise InvalidModelConfigException(f"{p}: Could not determine base type")
 
+    @classmethod
+    def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
+        p = path or mod.path
+
+        try:
+            state_dict = mod.load_state_dict(p)
+            if "string_to_token" in state_dict:
+                token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
+            elif "emb_params" in state_dict:
+                token_dim = state_dict["emb_params"].shape[-1]
+            elif "clip_g" in state_dict:
+                token_dim = state_dict["clip_g"].shape[-1]
+            else:
+                token_dim = list(state_dict.values())[0].shape[0]
+
+            match token_dim:
+                case 768:
+                    return BaseModelType.StableDiffusion1
+                case 1024:
+                    return BaseModelType.StableDiffusion2
+                case 1280:
+                    return BaseModelType.StableDiffusionXL
+                case _:
+                    pass
+        except Exception:
+            pass
+
+        raise NotAMatch(cls, f"{p}: could not determine base type")
+
 
 class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
     """Model config for textual inversion embeddings."""
 
@@ -911,6 +1257,29 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
 
         raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.TextualInversion:
+            raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
+
+        if format_override is not None and format_override is not ModelFormat.EmbeddingFile:
+            raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile")
+
+        if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile:
+            return cls(**fields)
+
+        if mod.path.is_dir():
+            raise NotAMatch(cls, "model path is a directory, not a file")
+
+        if not cls.file_looks_like_embedding(mod):
+            raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
+
+        fields["base"] = fields.get("base") or cls.get_base_or_raise(mod)
+        return cls(**fields)
+
 
 class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
     """Model config for textual inversion embeddings."""
 
@@ -949,6 +1318,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
 
         raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.TextualInversion:
+            raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
+
+        if format_override is not None and format_override is not ModelFormat.EmbeddingFolder:
+            raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder")
+
+        if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder:
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        for p in mod.weight_files():
+            if cls.file_looks_like_embedding(mod, p):
+                fields["base"] = fields.get("base") or cls.get_base_or_raise(mod, p)
+                return cls(**fields)
+
+        raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
+
 
 class MainConfigBase(ABC, BaseModel):
     type: Literal[ModelType.Main] = ModelType.Main
@@ -1100,6 +1493,39 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
 
         return MatchCertainty.NEVER
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+        variant_override = fields.get("variant")
+
+        if type_override is not None and type_override is not ModelType.CLIPEmbed:
+            raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
+
+        if format_override is not None and format_override is not ModelFormat.Diffusers:
+            raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
+
+        if variant_override is not None and variant_override is not ClipVariantType.G:
+            raise NotAMatch(cls, f"variant override is {variant_override}, not G")
+
+        if (
+            type_override is ModelType.CLIPEmbed
+            and format_override is ModelFormat.Diffusers
+            and variant_override is ClipVariantType.G
+        ):
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        is_clip_embed = cls.is_clip_text_encoder(mod)
+        clip_variant = cls.get_clip_variant_type(mod)
+
+        if not is_clip_embed or clip_variant is not ClipVariantType.G:
+            raise NotAMatch(cls, "model does not match CLIP-G heuristics")
+
+        return cls(**fields)
+
 
 class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
     """Model config for CLIP-L Embeddings."""
 
@@ -1130,6 +1556,39 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
 
         return MatchCertainty.NEVER
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+        variant_override = fields.get("variant")
+
+        if type_override is not None and type_override is not ModelType.CLIPEmbed:
+            raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
+
+        if format_override is not None and format_override is not ModelFormat.Diffusers:
+            raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
+
+        if variant_override is not None and variant_override is not ClipVariantType.L:
+            raise NotAMatch(cls, f"variant override is {variant_override}, not L")
+
+        if (
+            type_override is ModelType.CLIPEmbed
+            and format_override is ModelFormat.Diffusers
+            and variant_override is ClipVariantType.L
+        ):
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        is_clip_embed = cls.is_clip_text_encoder(mod)
+        clip_variant = cls.get_clip_variant_type(mod)
+
+        if not is_clip_embed or clip_variant is not ClipVariantType.L:
+            raise NotAMatch(cls, "model does not match CLIP-L heuristics")
+
+        return cls(**fields)
+
 
 class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for CLIPVision."""
 
@@ -1183,6 +1642,46 @@ class SpandrelImageToImageConfig(ModelConfigBase):
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
         return {}
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+        base_override = fields.get("base")
+
+        if type_override is not None and type_override is not ModelType.SpandrelImageToImage:
+            raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage")
+
+        if format_override is not None and format_override is not ModelFormat.Checkpoint:
+            raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
+
+        if base_override is not None and base_override is not BaseModelType.Any:
+            raise NotAMatch(cls, f"base override is {base_override}, not Any")
+
+        if (
+            type_override is ModelType.SpandrelImageToImage
+            and format_override is ModelFormat.Checkpoint
+            and base_override is BaseModelType.Any
+        ):
+            return cls(**fields)
+
+        if not mod.path.is_file():
+            raise NotAMatch(cls, "model path is a directory, not a file")
+
+        try:
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
+            SpandrelImageToImageModel.load_from_file(mod.path)
+            fields["base"] = fields.get("base") or BaseModelType.Any
+            return cls(**fields)
+        except Exception as e:
+            raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
+
 
 class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for SigLIP."""
 
@@ -1202,6 +1701,8 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
     """Model config for Llava Onevision models."""
 
     type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
+    base: Literal[BaseModelType.Any] = BaseModelType.Any
+    variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
 
     @classmethod
     def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
@@ -1234,6 +1735,41 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
             "variant": ModelVariantType.Normal,
         }
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        type_override = fields.get("type")
+        format_override = fields.get("format")
+
+        if type_override is not None and type_override is not ModelType.LlavaOnevision:
+            raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision")
+
+        if format_override is not None and format_override is not ModelFormat.Diffusers:
+            raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
+
+        if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers:
+            return cls(**fields)
+
+        if mod.path.is_file():
+            raise NotAMatch(cls, "model path is a file, not a directory")
+
+        # Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config
+        try:
+            config = load_json(mod.path / "config.json")
+        except Exception as e:
+            raise NotAMatch(cls, "unable to load config.json") from e
+
+        try:
+            config_class_name = get_class_name_from_config(config)
+        except Exception as e:
+            raise NotAMatch(cls, "unable to determine class name from config.json") from e
+
+        if config_class_name != "LlavaOnevisionForConditionalGeneration":
+            raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration")
+
+        fields["base"] = fields.get("base") or BaseModelType.Any
+        fields["variant"] = fields.get("variant") or ModelVariantType.Normal
+        return cls(**fields)
+
 
 class ApiModelConfig(MainConfigBase, ModelConfigBase):
     """Model config for API-based models."""
 
@@ -1249,6 +1785,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
         raise NotImplementedError("API models are not parsed from disk.")
 
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        raise NotAMatch(cls, "API models cannot be built from disk")
 
 
 class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
     """Model config for API-based video models."""
 
@@ -1264,6 +1803,10 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
         raise NotImplementedError("API models are not parsed from disk.")
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        raise NotAMatch(cls, "API models cannot be built from disk")
+
 
 def get_model_discriminator_value(v: Any) -> str:
     """
@@ -1342,6 +1885,15 @@ AnyModelConfig = Annotated[
 AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
 AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
 
+@dataclass
+class ModelClassificationResultSuccess:
+    model: AnyModelConfig
+
+@dataclass
+class ModelClassificationResultFailure:
+    error: Exception
+
+ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure
 
 class ModelConfigFactory:
     @staticmethod
@@ -1352,3 +1904,108 @@ class ModelConfigFactory:
             model.converted_at = timestamp
         validate_hash(model.hash)
         return model
+
+    @staticmethod
+    def build_common_fields(
+        mod: ModelOnDisk,
+        overrides: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Builds the common fields for all model configs.
+
+        Args:
+            mod: The model on disk to extract fields from.
+            overrides: An optional dictionary of fields to override. These fields will take precedence over the
+                values extracted from the model on disk.
+
+        - Casts string fields to their Enum types.
+        - Does not validate the fields against the model config schema.
+        """
+
+        _overrides: dict[str, Any] = overrides or {}
+        fields: dict[str, Any] = {}
+
+        if "type" in _overrides:
+            fields["type"] = ModelType(_overrides["type"])
+
+        if "format" in _overrides:
+            fields["format"] = ModelFormat(_overrides["format"])
+
+        if "base" in _overrides:
+            fields["base"] = BaseModelType(_overrides["base"])
+
+        if "source_type" in _overrides:
+            fields["source_type"] = ModelSourceType(_overrides["source_type"])
+
+        if "variant" in _overrides:
+            fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"])
+
+        fields["path"] = mod.path.as_posix()
+        fields["source"] = _overrides.get("source") or fields["path"]
+        fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
+        fields["name"] = _overrides.get("name") or mod.name
+        fields["hash"] = _overrides.get("hash") or mod.hash()
+        fields["key"] = _overrides.get("key") or uuid_string()
+        fields["description"] = _overrides.get("description")
+        fields["repo_variant"] = _overrides.get("repo_variant") or mod.repo_variant()
+        fields["file_size"] = _overrides.get("file_size") or mod.size()
+
+        return fields
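+
+    # NOTE: Illustrative example of the dict this method returns; the concrete values here are
+    # hypothetical, not taken from a real model:
+    #   {"type": ModelType.LoRA, "base": BaseModelType.StableDiffusionXL,
+    #    "path": "/models/my-lora.safetensors", "source": "/models/my-lora.safetensors",
+    #    "source_type": ModelSourceType.Path, "name": "my-lora", "hash": "blake3:...",
+    #    "key": "...", "description": None, "repo_variant": None, "file_size": 123456}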
+
+    @staticmethod
+    def from_model_on_disk(
+        mod: str | Path | ModelOnDisk,
+        overrides: dict[str, Any] | None = None,
+        hash_algo: HASHING_ALGORITHMS = "blake3_single",
+    ) -> AnyModelConfig:
+        """
+        Returns the best matching ModelConfig instance from a model's file/folder path.
+        Raises InvalidModelConfigException if no valid configuration is found.
+        Created to deprecate ModelProbe.probe
+        """
+        if isinstance(mod, Path | str):
+            mod = ModelOnDisk(Path(mod), hash_algo)
+
+        # We will always need these fields to build any model config.
+        fields = ModelConfigFactory.build_common_fields(mod, overrides)
+
+        # Store results as a mapping of config class to either an instance of that class or an exception
+        # that was raised when trying to build it.
+        results: dict[type[AnyModelConfig], AnyModelConfig | Exception] = {}
+
+        # Try to build an instance of each model config class that uses the classify API.
+        # Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
+        # Other exceptions may be raised if something unexpected happens during matching or building.
+        for config_class in ModelConfigBase.USING_CLASSIFY_API:
+            try:
+                instance = config_class.from_model_on_disk(mod, fields)
+                results[config_class] = instance
+            except NotAMatch as e:
+                results[config_class] = e
+                logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
+            except ValidationError as e:
+                # This means the model matched, but we couldn't create the pydantic model instance for the config.
+                # Maybe invalid overrides were provided?
+                results[config_class] = e
+                logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
+            except Exception as e:
+                results[config_class] = e
+                logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
+
+        matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
+
+        if not matches:
+            if app_config.allow_unknown_models:
+                logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig")
+                return UnknownModelConfig.from_model_on_disk(mod, fields)
+            raise InvalidModelConfigException(f"Unable to identify model {mod.name}")
+
+        instance = matches[0]
+        if len(matches) > 1:
+            # TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the
+            # matches, probably on a case-by-case basis.
+            #
+            # One known case is certain SD main (pipeline) models can look like a LoRA. This could happen if the model
+            # contains merged in LoRA weights.
+            logger.warning(
+                f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. "
+                f"Using {type(instance).__name__}."
+            )
+        logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
+        return instance
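
Reviewer note: the snippet below is a minimal sketch of how the new entry point might be exercised, not part of the diff itself. The paths and override values are hypothetical.

    from invokeai.backend.model_manager.config import ModelConfigFactory

    # Let the classify heuristics pick the config class for a model folder.
    config = ModelConfigFactory.from_model_on_disk("/models/t5-xxl-encoder")

    # Or short-circuit the heuristics with explicit overrides; build_common_fields
    # casts the string values to their enum types before matching.
    config = ModelConfigFactory.from_model_on_disk(
        "/models/my-lora.safetensors",
        overrides={"type": "lora", "base": "sdxl", "name": "my-lora"},
    )
    print(type(config).__name__, config.base)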