diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index a5c2058e1b..8f5f7c89f6 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,7 +21,6 @@ Validation errors will raise an InvalidModelConfigException error. """ # pyright: reportIncompatibleVariableOverride=false -from dataclasses import dataclass import json import logging import re @@ -40,7 +39,6 @@ from typing import ( Union, ) -import spandrel import torch from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError from typing_extensions import Annotated, Any, Dict @@ -88,7 +86,7 @@ class NotAMatch(Exception): reason: The reason why the model did not match. """ - def __init__(self, config_class: "Type[ModelConfigBase]", reason: str): + def __init__(self, config_class: "Type[AnyModelConfig]", reason: str): super().__init__(f"{config_class.__name__} does not match: {reason}") @@ -104,6 +102,19 @@ def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]: return None +def validate_overrides( + config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any] +) -> None: + for key, value in allowed.items(): + if key not in overrides: + continue + if overrides[key] != value: + raise NotAMatch( + config_class, + f"override {key}={overrides[key]} does not match required value {key}={value}", + ) + + class SubmodelDefinition(BaseModel): path_or_prefix: str model_type: ModelType @@ -139,23 +150,6 @@ class ControlAdapterDefaultSettings(BaseModel): model_config = ConfigDict(extra="forbid") -class MatchSpeed(int, Enum): - """Represents the estimated runtime speed of a config's 'matches' method.""" - - FAST = 0 - MED = 1 - SLOW = 2 - - -class MatchCertainty(int, Enum): - """Represents the certainty of a config's 'matches' method.""" - - NEVER = 0 - MAYBE = 1 - EXACT = 2 - OVERRIDE = 3 - - class LegacyProbeMixin: """Mixin for classes using the legacy probe for model classification.""" @@ -213,7 +207,6 @@ class ModelConfigBase(ABC, BaseModel): USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set() USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set() - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -228,132 +221,20 @@ class ModelConfigBase(ABC, BaseModel): concrete = {cls for cls in subclasses if not isabstract(cls)} return concrete - @staticmethod - def classify( - mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides - ) -> "AnyModelConfig": - """ - Returns the best matching ModelConfig instance from a model's file/folder path. - Raises InvalidModelConfigException if no valid configuration is found. - Created to deprecate ModelProbe.probe - """ - if isinstance(mod, Path | str): - mod = ModelOnDisk(Path(mod), hash_algo) - - candidates = ModelConfigBase.USING_CLASSIFY_API - sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__)) - - overrides = overrides or {} - ModelConfigBase.cast_overrides(**overrides) - - matches: dict[Type[ModelConfigBase], MatchCertainty] = {} - - for config_cls in sorted_by_match_speed: - try: - score = config_cls.matches(mod, **overrides) - - # A score of 0 means "no match" - if score is MatchCertainty.NEVER: - continue - - matches[config_cls] = score - - if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE: - # Perfect match - skip further checks - break - except Exception as e: - logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}") - continue - - if matches: - # Select the config class with the highest score - sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value) - # Check if there are multiple classes with the same top score - top_score = sorted_by_score[-1][1] - top_classes = [cls for cls, score in sorted_by_score if score is top_score] - if len(top_classes) > 1: - logger.warning( - f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}." - ) - config_cls = top_classes[0] - # Finally, create the config instance - logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}") - return config_cls.from_model_on_disk(mod, **overrides) - - if app_config.allow_unknown_models: - try: - return UnknownModelConfig.from_model_on_disk(mod, **overrides) - except Exception: - # Fall through to raising the exception below - pass - - raise InvalidModelConfigException("Unable to determine model type") - @classmethod def get_tag(cls) -> Tag: type = cls.model_fields["type"].default.value format = cls.model_fields["format"].default.value return Tag(f"{type}.{format}") - @classmethod - @abstractmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - """Returns a dictionary with the fields needed to construct the model. - Raises InvalidModelConfigException if the model is invalid. - """ - pass - - @classmethod - @abstractmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - """Performs a quick check to determine if the config matches the model. - Returns a MatchCertainty score.""" - pass - @classmethod @abstractmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - """Performs a quick check to determine if the config matches the model. - Returns a MatchCertainty score.""" + """Given the model on disk and any overrides, return an instance of this config class. + + Implementations should raise NotAMatch if the model does not match this config class.""" pass - @staticmethod - def cast_overrides(**overrides): - """Casts user overrides from str to Enum""" - if "type" in overrides: - overrides["type"] = ModelType(overrides["type"]) - - if "format" in overrides: - overrides["format"] = ModelFormat(overrides["format"]) - - if "base" in overrides: - overrides["base"] = BaseModelType(overrides["base"]) - - if "source_type" in overrides: - overrides["source_type"] = ModelSourceType(overrides["source_type"]) - - if "variant" in overrides: - overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"]) - - @classmethod - def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides): - """Creates an instance of this config or raises InvalidModelConfigException.""" - fields = cls.parse(mod) - cls.cast_overrides(**overrides) - fields.update(overrides) - - fields["path"] = mod.path.as_posix() - fields["source"] = fields.get("source") or fields["path"] - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["name"] = fields.get("name") or mod.name - fields["hash"] = fields.get("hash") or mod.hash() - fields["key"] = fields.get("key") or uuid_string() - fields["description"] = fields.get("description") - fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant() - fields["file_size"] = fields.get("file_size") or mod.size() - - return cls(**fields) - class UnknownModelConfig(ModelConfigBase): base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown @@ -361,12 +242,8 @@ class UnknownModelConfig(ModelConfigBase): format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return {} + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + raise NotAMatch(cls, "unknown model config cannot match any model") class CheckpointConfigBase(ABC, BaseModel): @@ -441,16 +318,6 @@ class T5EncoderConfigBase(ABC, BaseModel): base: Literal[BaseModelType.Any] = BaseModelType.Any type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder - @classmethod - def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]: - path = mod.path / "text_encoder_2" / "config.json" - with open(path, "r") as file: - return json.load(file) - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return {} - def load_json(path: Path) -> dict[str, Any]: with open(path, "r") as file: @@ -460,35 +327,6 @@ def load_json(path: Path) -> dict[str, Any]: class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_t5_type_override = overrides.get("type") is ModelType.T5Encoder - is_t5_format_override = overrides.get("format") is ModelFormat.T5Encoder - - if is_t5_type_override and is_t5_format_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - model_dir = mod.path / "text_encoder_2" - - if not model_dir.exists(): - return MatchCertainty.NEVER - - try: - config = cls.get_config(mod) - - is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel" - is_t5_format = (model_dir / "model.safetensors.index.json").exists() - - if is_t5_encoder_model and is_t5_format: - return MatchCertainty.EXACT - except Exception: - pass - - return MatchCertainty.NEVER - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -532,44 +370,6 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_t5_type_override = overrides.get("type") is ModelType.T5Encoder - is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b - - if is_t5_type_override and is_bnb_format_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - model_dir = mod.path / "text_encoder_2" - - if not model_dir.exists(): - return MatchCertainty.NEVER - - try: - config = cls.get_config(mod) - - is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel" - - # Heuristic: look for the quantization in the name - files = model_dir.glob("*.safetensors") - filename_looks_like_bnb = any(x for x in files if "llm_int8" in x.as_posix()) - - if is_t5_encoder_model and filename_looks_like_bnb: - return MatchCertainty.EXACT - - # Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix) - has_scb_key = mod.has_keys_ending_with("SCB") - - if is_t5_encoder_model and has_scb_key: - return MatchCertainty.EXACT - except Exception: - pass - - return MatchCertainty.NEVER - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -651,7 +451,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): return cls(**fields, base=base) @classmethod - def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType + def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: metadata = mod.metadata() architecture = metadata["modelspec.architecture"] @@ -662,112 +462,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): else: raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_lora_override = overrides.get("type") is ModelType.LoRA - is_omi_override = overrides.get("format") is ModelFormat.OMI - - # If both type and format are overridden, skip the heuristic checks - if is_lora_override and is_omi_override: - return MatchCertainty.OVERRIDE - - # OMI LoRAs are always files, never directories - if mod.path.is_dir(): - return MatchCertainty.NEVER - - # Avoid false positive match against ControlLoRA and Diffusers - if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - return MatchCertainty.NEVER - - metadata = mod.metadata() - is_omi_lora_heuristic = ( - bool(metadata.get("modelspec.sai_model_spec")) - and metadata.get("ot_branch") == "omi_format" - and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" - ) - - if is_omi_lora_heuristic: - return MatchCertainty.EXACT - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - metadata = mod.metadata() - architecture = metadata["modelspec.architecture"] - - if architecture == stable_diffusion_xl_1_lora: - base = BaseModelType.StableDiffusionXL - elif architecture == flux_dev_1_lora: - base = BaseModelType.Flux - else: - raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}") - - return {"base": base} - class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): """Model config for LoRA/Lycoris models.""" format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_lora_override = overrides.get("type") is ModelType.LoRA - is_omi_override = overrides.get("format") is ModelFormat.LyCORIS - - # If both type and format are overridden, skip the heuristic checks and return a perfect score - if is_lora_override and is_omi_override: - return MatchCertainty.OVERRIDE - - # LyCORIS LoRAs are always files, never directories - if mod.path.is_dir(): - return MatchCertainty.NEVER - - # Avoid false positive match against ControlLoRA and Diffusers - if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - return MatchCertainty.NEVER - - state_dict = mod.load_state_dict() - for key in state_dict.keys(): - if isinstance(key, int): - continue - - # Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. - # Some main models have these keys, likely due to the creator merging in a LoRA. - - has_key_with_lora_prefix = key.startswith( - ( - "lora_te_", - "lora_unet_", - "lora_te1_", - "lora_te2_", - "lora_transformer_", - ) - ) - - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - has_key_with_lora_suffix = key.endswith( - ( - "to_k_lora.up.weight", - "to_q_lora.down.weight", - "lora_A.weight", - "lora_B.weight", - ) - ) - - if has_key_with_lora_prefix or has_key_with_lora_suffix: - return MatchCertainty.MAYBE - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -844,39 +544,6 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_lora_override = overrides.get("type") is ModelType.LoRA - is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers - - # If both type and format are overridden, skip the heuristic checks and return a perfect score - if is_lora_override and is_diffusers_override: - return MatchCertainty.OVERRIDE - - # Diffusers LoRAs are always directories, never files - if mod.path.is_file(): - return MatchCertainty.NEVER - - is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers - - suffixes = ["bin", "safetensors"] - weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - has_lora_weight_file = any(wf.exists() for wf in weight_files) - - if is_flux_lora_diffusers and has_lora_weight_file: - return MatchCertainty.EXACT - - if is_flux_lora_diffusers or has_lora_weight_file: - return MatchCertainty.MAYBE - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -916,56 +583,6 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint - KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"} - - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_vae_override = overrides.get("type") is ModelType.VAE - is_checkpoint_override = overrides.get("format") is ModelFormat.Checkpoint - - if is_vae_override and is_checkpoint_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_dir(): - return MatchCertainty.NEVER - - if mod.has_keys_starting_with(cls.KEY_PREFIXES): - return MatchCertainty.MAYBE - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - base = cls.get_base_type(mod) - config_path = ( - # For flux, this is a key in invokeai.backend.flux.util.ae_params - # Due to model type and format being the descriminator for model configs this - # is used rather than attempting to support flux with separate model types and format - # If changed in the future, please fix me - "flux" - if base is BaseModelType.Flux - else "stable-diffusion/v1-inference.yaml" - if base is BaseModelType.StableDiffusion1 - else "stable-diffusion/sd_xl_base.yaml" - if base is BaseModelType.StableDiffusionXL - else "stable-diffusion/v2-inference.yaml" - ) - return {"base": base, "config_path": config_path} - - @classmethod - def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType: - # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name - for regexp, basetype in [ - (r"xl", BaseModelType.StableDiffusionXL), - (r"sd2", BaseModelType.StableDiffusion2), - (r"vae", BaseModelType.StableDiffusion1), - (r"FLUX.1-schnell_ae", BaseModelType.Flux), - ]: - if re.search(regexp, mod.path.name, re.IGNORECASE): - return basetype - - raise InvalidModelConfigException("Cannot determine base type") - @classmethod def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name @@ -1012,50 +629,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"} @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_vae_override = overrides.get("type") is ModelType.VAE - is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers - - if is_vae_override and is_diffusers_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - try: - config = cls.get_config(mod) - class_name = get_class_name_from_config(config) - if class_name in cls.CLASS_NAMES: - return MatchCertainty.EXACT - except Exception: - pass - - return MatchCertainty.NEVER - - @classmethod - def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]: - config_path = mod.path / "config.json" - with open(config_path, "r") as file: - return json.load(file) - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - base = cls.get_base_type(mod) - return {"base": base} - - @classmethod - def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType: - if cls._config_looks_like_sdxl(mod): - return BaseModelType.StableDiffusionXL - elif cls._name_looks_like_sdxl(mod): - return BaseModelType.StableDiffusionXL - else: - # We do not support diffusers VAEs for any other base model at this time... YOLO - return BaseModelType.StableDiffusion1 - - @classmethod - def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool: - config = cls.get_config(mod) + def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] @@ -1074,8 +648,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): return name @classmethod - def get_base(cls, mod: ModelOnDisk) -> BaseModelType: - if cls._config_looks_like_sdxl(mod): + def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType: + if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL elif cls._name_looks_like_sdxl(mod): return BaseModelType.StableDiffusionXL @@ -1113,7 +687,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): if config_class_name not in cls.CLASS_NAMES: raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") - base = fields.get("base") or cls.get_base(mod) + base = fields.get("base") or cls.get_base(mod, config) return cls(**fields, base=base) @@ -1231,32 +805,6 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): def get_tag(cls) -> Tag: return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}") - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_embedding_override = overrides.get("type") is ModelType.TextualInversion - is_file_override = overrides.get("format") is ModelFormat.EmbeddingFile - - if is_embedding_override and is_file_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_dir(): - return MatchCertainty.NEVER - - if cls.file_looks_like_embedding(mod): - return MatchCertainty.MAYBE - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - try: - base = cls.get_base(mod) - return {"base": base} - except Exception: - pass - - raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1290,34 +838,6 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): def get_tag(cls) -> Tag: return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_embedding_override = overrides.get("type") is ModelType.TextualInversion - is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder - - if is_embedding_override and is_folder_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - for p in mod.path.iterdir(): - if cls.file_looks_like_embedding(mod, p): - return MatchCertainty.MAYBE - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - try: - for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}: - base = cls.get_base(mod, mod.path / filename) - return {"base": base} - except Exception: - pass - - raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1367,14 +887,6 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False - # @classmethod - # def matches(cls, mod: ModelOnDisk) -> bool: - # pass - - # @classmethod - # def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - # pass - class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for main checkpoint models.""" @@ -1425,44 +937,26 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers base: Literal[BaseModelType.Any] = BaseModelType.Any - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - clip_variant = cls.get_clip_variant_type(mod) - if clip_variant is None: - raise InvalidModelConfigException("Unable to determine CLIP variant type") - - return {"variant": clip_variant} + CLASS_NAMES: ClassVar = { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + } @classmethod - def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None: + def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None: try: - with open(mod.path / "config.json") as file: - config = json.load(file) - hidden_size = config.get("hidden_size") - match hidden_size: - case 1280: - return ClipVariantType.G - case 768: - return ClipVariantType.L - case _: - return None + hidden_size = config.get("hidden_size") + match hidden_size: + case 1280: + return ClipVariantType.G + case 768: + return ClipVariantType.L + case _: + return None except Exception: return None - @classmethod - def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool: - try: - with open(mod.path / "config.json", "r") as file: - config = json.load(file) - architectures = config.get("architectures") - return architectures[0] in ( - "CLIPModel", - "CLIPTextModel", - "CLIPTextModelWithProjection", - ) - except Exception: - return False - class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): """Model config for CLIP-G Embeddings.""" @@ -1473,26 +967,6 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed - is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers - has_clip_variant_override = overrides.get("variant") is ClipVariantType.G - - if is_clip_embed_override and is_diffusers_override and has_clip_variant_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - is_clip_embed = cls.is_clip_text_encoder(mod) - clip_variant = cls.get_clip_variant_type(mod) - - if is_clip_embed and clip_variant is ClipVariantType.G: - return MatchCertainty.EXACT - - return MatchCertainty.NEVER - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1518,10 +992,22 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - is_clip_embed = cls.is_clip_text_encoder(mod) - clip_variant = cls.get_clip_variant_type(mod) + try: + config = load_json(mod.path / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load config.json") from e - if not is_clip_embed or clip_variant is not ClipVariantType.G: + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from config") from e + + if config_class_name not in cls.CLASS_NAMES: + raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") + + clip_variant = cls.get_clip_variant_type(config) + + if clip_variant is not ClipVariantType.G: raise NotAMatch(cls, "model does not match CLIP-G heuristics") return cls(**fields) @@ -1536,26 +1022,6 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed - is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers - has_clip_variant_override = overrides.get("variant") is ClipVariantType.L - - if is_clip_embed_override and is_diffusers_override and has_clip_variant_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - is_clip_embed = cls.is_clip_text_encoder(mod) - clip_variant = cls.get_clip_variant_type(mod) - - if is_clip_embed and clip_variant is ClipVariantType.L: - return MatchCertainty.EXACT - - return MatchCertainty.NEVER - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1581,10 +1047,22 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - is_clip_embed = cls.is_clip_text_encoder(mod) - clip_variant = cls.get_clip_variant_type(mod) + try: + config = load_json(mod.path / "config.json") + except Exception as e: + raise NotAMatch(cls, "unable to load config.json") from e - if not is_clip_embed or clip_variant is not ClipVariantType.L: + try: + config_class_name = get_class_name_from_config(config) + except Exception as e: + raise NotAMatch(cls, "unable to determine class name from config") from e + + if config_class_name not in cls.CLASS_NAMES: + raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") + + clip_variant = cls.get_clip_variant_type(config) + + if clip_variant is not ClipVariantType.L: raise NotAMatch(cls, "model does not match CLIP-L heuristics") return cls(**fields) @@ -1607,41 +1085,10 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb class SpandrelImageToImageConfig(ModelConfigBase): """Model config for Spandrel Image to Image models.""" - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk - base: Literal[BaseModelType.Any] = BaseModelType.Any type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - if not mod.path.is_file(): - return MatchCertainty.NEVER - - try: - # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were - # explored to avoid this: - # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta - # device. Unfortunately, some Spandrel models perform operations during initialization that are not - # supported on meta tensors. - # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. - # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to - # maintain it, and the risk of false positive detections is higher. - SpandrelImageToImageModel.load_from_file(mod.path) - return MatchCertainty.EXACT - except spandrel.UnsupportedModelError: - pass - except Exception as e: - logger.warning( - f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}" - ) - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return {} - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1704,37 +1151,6 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): base: Literal[BaseModelType.Any] = BaseModelType.Any variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - is_llava_override = overrides.get("type") is ModelType.LlavaOnevision - is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers - - if is_llava_override and is_diffusers_override: - return MatchCertainty.OVERRIDE - - if mod.path.is_file(): - return MatchCertainty.NEVER - - config_path = mod.path / "config.json" - try: - with open(config_path, "r") as file: - config = json.load(file) - except FileNotFoundError: - return MatchCertainty.NEVER - - architectures = config.get("architectures") - if architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration": - return MatchCertainty.EXACT - - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": BaseModelType.Any, - "variant": ModelVariantType.Normal, - } - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: type_override = fields.get("type") @@ -1776,33 +1192,16 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase): format: Literal[ModelFormat.Api] = ModelFormat.Api - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - # API models are not stored on disk, so we can't match them. - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: raise NotAMatch(cls, "API models cannot be built from disk") + class VideoApiModelConfig(VideoConfigBase, ModelConfigBase): """Model config for API-based video models.""" format: Literal[ModelFormat.Api] = ModelFormat.Api - @classmethod - def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: - # API models are not stored on disk, so we can't match them. - return MatchCertainty.NEVER - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") - @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: raise NotAMatch(cls, "API models cannot be built from disk") @@ -1885,15 +1284,6 @@ AnyModelConfig = Annotated[ AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings] -@dataclass -class ModelClassificationResultSuccess: - model: AnyModelConfig - -@dataclass -class ModelClassificationResultFailure: - error: Exception - -ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure class ModelConfigFactory: @staticmethod