refactor(mm): remove unused methods in config.py

This commit is contained in:
psychedelicious
2025-09-24 16:49:59 +10:00
parent 8399de9c25
commit fd47da6842

View File

@@ -21,7 +21,6 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
from dataclasses import dataclass
import json
import logging
import re
@@ -40,7 +39,6 @@ from typing import (
Union,
)
import spandrel
import torch
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
from typing_extensions import Annotated, Any, Dict
@@ -88,7 +86,7 @@ class NotAMatch(Exception):
reason: The reason why the model did not match.
"""
def __init__(self, config_class: "Type[ModelConfigBase]", reason: str):
def __init__(self, config_class: "Type[AnyModelConfig]", reason: str):
super().__init__(f"{config_class.__name__} does not match: {reason}")
@@ -104,6 +102,19 @@ def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
return None
def validate_overrides(
    config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any]
) -> None:
    """Reject overrides that contradict the values a config class requires.

    Only keys present in ``allowed`` are inspected, in ``allowed``'s iteration order. A key that is
    absent from ``overrides`` is not an error.

    Args:
        config_class: The config class being matched; forwarded to ``NotAMatch``.
        overrides: Override values supplied for the model.
        allowed: Maps an override key to the one value that is acceptable for it.

    Raises:
        NotAMatch: On the first key whose override is present and differs from the required value.
    """
    for key, value in allowed.items():
        if key in overrides and overrides[key] != value:
            raise NotAMatch(
                config_class,
                f"override {key}={overrides[key]} does not match required value {key}={value}",
            )
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -139,23 +150,6 @@ class ControlAdapterDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
class MatchSpeed(int, Enum):
    """Represents the estimated runtime speed of a config's 'matches' method."""

    # Members are ints, so they order naturally: FAST < MED < SLOW.
    FAST = 0
    MED = 1
    SLOW = 2
class MatchCertainty(int, Enum):
    """Represents the certainty of a config's 'matches' method."""

    # Values increase with certainty, so members can be ranked by ``.value``.
    NEVER = 0  # no match
    MAYBE = 1  # heuristic match only
    EXACT = 2  # confident match
    OVERRIDE = 3  # match dictated by explicit overrides
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@@ -213,7 +207,6 @@ class ModelConfigBase(ABC, BaseModel):
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@@ -228,132 +221,20 @@ class ModelConfigBase(ABC, BaseModel):
concrete = {cls for cls in subclasses if not isabstract(cls)}
return concrete
@staticmethod
def classify(
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
) -> "AnyModelConfig":
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe

Every class in USING_CLASSIFY_API is asked to score the model via its 'matches' method. The
highest-scoring class builds the config. If nothing matches and unknown models are allowed,
UnknownModelConfig is tried as a fallback.
"""
# A plain path is wrapped so the matchers always receive a ModelOnDisk.
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
candidates = ModelConfigBase.USING_CLASSIFY_API
# Cheapest matchers run first; the class name breaks ties so the order is deterministic.
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
overrides = overrides or {}
# NOTE(review): cast_overrides collects its arguments with **overrides, i.e. into its own dict, and
# returns nothing - so the casts performed there do not appear to reach this dict. Confirm.
ModelConfigBase.cast_overrides(**overrides)
matches: dict[Type[ModelConfigBase], MatchCertainty] = {}
for config_cls in sorted_by_match_speed:
try:
score = config_cls.matches(mod, **overrides)
# A score of 0 means "no match"
if score is MatchCertainty.NEVER:
continue
matches[config_cls] = score
if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE:
# Perfect match - skip further checks
break
# A matcher that raises is logged and treated as "no match" so one bad matcher cannot abort classification.
except Exception as e:
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
if matches:
# Select the config class with the highest score
# (ascending sort, so the best score is the last item)
sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value)
# Check if there are multiple classes with the same top score
top_score = sorted_by_score[-1][1]
top_classes = [cls for cls, score in sorted_by_score if score is top_score]
if len(top_classes) > 1:
logger.warning(
f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}."
)
# Ties resolve to the first class in sorted order.
config_cls = top_classes[0]
# Finally, create the config instance
# NOTE(review): overrides are passed as keyword arguments here, while the abstract
# from_model_on_disk in this file takes a single 'fields' dict - verify the call shape.
logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}")
return config_cls.from_model_on_disk(mod, **overrides)
# Nothing matched: optionally fall back to the catch-all unknown config.
if app_config.allow_unknown_models:
try:
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
except Exception:
# Fall through to raising the exception below
pass
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:
    """Return the ``Tag`` for this config class, formatted as ``"<type>.<format>"``.

    Both parts come from the default values of the class's ``type`` and ``format`` fields.
    """
    # Local names avoid shadowing the builtins `type` and `format`.
    model_type = cls.model_fields["type"].default.value
    model_format = cls.model_fields["format"].default.value
    return Tag(f"{model_type}.{model_format}")
@classmethod
@abstractmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the fields needed to construct this config for the given model.

    Implementations raise InvalidModelConfigException when the model is invalid.
    """
    ...
@classmethod
@abstractmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Quickly check whether this config class matches the model.

    Implementations return a MatchCertainty score.
    """
    ...
@classmethod
@abstractmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Given the model on disk and any overrides, return an instance of this config class.

    Args:
        mod: The model on disk to build a config for.
        fields: Override values for the config's fields.

    Returns:
        An instance of this config class.

    Raises:
        NotAMatch: Implementations should raise this if the model does not match this config class.
    """
    pass
@staticmethod
def cast_overrides(**overrides):
    """Cast user overrides from str to Enum and return them.

    Keyword arguments are collected into a fresh dict on every call, so the caller's own mapping is
    never modified. Callers must use the returned dict to see the cast values.

    Returns:
        The overrides, with ``type``, ``format``, ``base``, ``source_type`` and ``variant`` converted
        when present. All other keys are passed through unchanged.
    """
    if "type" in overrides:
        overrides["type"] = ModelType(overrides["type"])
    if "format" in overrides:
        overrides["format"] = ModelFormat(overrides["format"])
    if "base" in overrides:
        overrides["base"] = BaseModelType(overrides["base"])
    if "source_type" in overrides:
        overrides["source_type"] = ModelSourceType(overrides["source_type"])
    if "variant" in overrides:
        overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
    return overrides


@classmethod
def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides):
    """Creates an instance of this config or raises InvalidModelConfigException.

    Fields come from ``cls.parse(mod)``; overrides (cast to their enum types) take precedence. Any of
    the bookkeeping fields still missing or falsy afterwards are filled in from the model on disk.
    """
    fields = cls.parse(mod)
    # Use the returned dict: the casts are applied to cast_overrides' own copy of the kwargs.
    fields.update(cls.cast_overrides(**overrides))

    fields["path"] = mod.path.as_posix()
    fields["source"] = fields.get("source") or fields["path"]
    fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
    fields["name"] = fields.get("name") or mod.name
    fields["hash"] = fields.get("hash") or mod.hash()
    fields["key"] = fields.get("key") or uuid_string()
    # Ensures the key exists (as None) even when no description was supplied.
    fields["description"] = fields.get("description")
    fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
    fields["file_size"] = fields.get("file_size") or mod.size()

    return cls(**fields)
class UnknownModelConfig(ModelConfigBase):
base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown
@@ -361,12 +242,8 @@ class UnknownModelConfig(ModelConfigBase):
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {}
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatch(cls, "unknown model config cannot match any model")
class CheckpointConfigBase(ABC, BaseModel):
@@ -441,16 +318,6 @@ class T5EncoderConfigBase(ABC, BaseModel):
base: Literal[BaseModelType.Any] = BaseModelType.Any
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
@classmethod
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Load and return ``text_encoder_2/config.json`` from the model directory as parsed JSON."""
    config_file = mod.path / "text_encoder_2" / "config.json"
    with open(config_file, "r") as handle:
        parsed = json.load(handle)
    return parsed
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return an empty dict: no fields are derived from the model on disk here."""
    return dict()
def load_json(path: Path) -> dict[str, Any]:
with open(path, "r") as file:
@@ -460,35 +327,6 @@ def load_json(path: Path) -> dict[str, Any]:
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the T5 encoder format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, EXACT when the
    ``text_encoder_2`` config names ``T5EncoderModel`` and a safetensors index file is present, and
    NEVER otherwise (including when reading the config fails).
    """
    if overrides.get("type") is ModelType.T5Encoder and overrides.get("format") is ModelFormat.T5Encoder:
        return MatchCertainty.OVERRIDE

    encoder_dir = mod.path / "text_encoder_2"
    if mod.path.is_file() or not encoder_dir.exists():
        return MatchCertainty.NEVER

    try:
        class_name = get_class_name_from_config(cls.get_config(mod))
        has_index_file = (encoder_dir / "model.safetensors.index.json").exists()
    except Exception:
        # Best-effort probe: an unreadable config simply means "not a match".
        return MatchCertainty.NEVER

    if class_name == "T5EncoderModel" and has_index_file:
        return MatchCertainty.EXACT
    return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -532,44 +370,6 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the bnb LLM.int8 quantized T5 encoder format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config. Returns EXACT
    when the ``text_encoder_2`` config names ``T5EncoderModel`` and either a safetensors filename
    contains ``llm_int8`` or the model has keys ending with ``SCB``. Returns NEVER otherwise,
    including when probing raises.
    """
    is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
    is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b
    if is_t5_type_override and is_bnb_format_override:
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    model_dir = mod.path / "text_encoder_2"
    if not model_dir.exists():
        return MatchCertainty.NEVER

    try:
        config = cls.get_config(mod)
        if get_class_name_from_config(config) != "T5EncoderModel":
            # Neither heuristic below can yield a match, so skip the filesystem and key scans.
            return MatchCertainty.NEVER

        # Heuristic: look for the quantization in the name
        if any("llm_int8" in file.as_posix() for file in model_dir.glob("*.safetensors")):
            return MatchCertainty.EXACT

        # Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
        if mod.has_keys_ending_with("SCB"):
            return MatchCertainty.EXACT
    except Exception:
        # Deliberate best-effort: any failure while probing is treated as "not a match".
        pass

    return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -651,7 +451,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
@@ -662,112 +462,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
else:
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the OMI LoRA format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, EXACT when the
    model metadata carries the OMI markers and an architecture whose second ``/``-separated segment is
    ``lora``, and NEVER otherwise.
    """
    is_lora_override = overrides.get("type") is ModelType.LoRA
    is_omi_override = overrides.get("format") is ModelFormat.OMI

    # If both type and format are overridden, skip the heuristic checks
    if is_lora_override and is_omi_override:
        return MatchCertainty.OVERRIDE

    # OMI LoRAs are always files, never directories
    if mod.path.is_dir():
        return MatchCertainty.NEVER

    # Avoid false positive match against ControlLoRA and Diffusers
    if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
        return MatchCertainty.NEVER

    metadata = mod.metadata()
    if not metadata.get("modelspec.sai_model_spec"):
        return MatchCertainty.NEVER
    if metadata.get("ot_branch") != "omi_format":
        return MatchCertainty.NEVER

    # An architecture without a "/" has no second segment; treat it as "not a LoRA" rather than
    # letting an IndexError escape.
    segments = metadata.get("modelspec.architecture", "").split("/")
    if len(segments) > 1 and segments[1].lower() == "lora":
        return MatchCertainty.EXACT

    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Derive the base model type from the ``modelspec.architecture`` metadata entry.

    Raises:
        InvalidModelConfigException: If the architecture is neither of the two recognised values.
    """
    architecture = mod.metadata()["modelspec.architecture"]

    if architecture == stable_diffusion_xl_1_lora:
        return {"base": BaseModelType.StableDiffusionXL}
    if architecture == flux_dev_1_lora:
        return {"base": BaseModelType.Flux}

    raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the LyCORIS LoRA format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, MAYBE when any
    state dict key carries one of the known LoRA prefixes or suffixes, and NEVER otherwise.
    """
    # If both type and format are overridden, skip the heuristic checks and return a perfect score
    if overrides.get("type") is ModelType.LoRA and overrides.get("format") is ModelFormat.LyCORIS:
        return MatchCertainty.OVERRIDE

    # LyCORIS LoRAs are always files, never directories
    if mod.path.is_dir():
        return MatchCertainty.NEVER

    # Avoid false positive match against ControlLoRA and Diffusers
    if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
        return MatchCertainty.NEVER

    # Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
    # Some main models have these keys, likely due to the creator merging in a LoRA.
    lora_prefixes = (
        "lora_te_",
        "lora_unet_",
        "lora_te1_",
        "lora_te2_",
        "lora_transformer_",
    )
    # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
    # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
    lora_suffixes = (
        "to_k_lora.up.weight",
        "to_q_lora.down.weight",
        "lora_A.weight",
        "lora_B.weight",
    )

    for key in mod.load_state_dict().keys():
        if isinstance(key, int):
            continue
        if key.startswith(lora_prefixes) or key.endswith(lora_suffixes):
            return MatchCertainty.MAYBE

    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` field, as determined by ``cls.base_model``."""
    base = cls.base_model(mod)
    return {"base": base}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -844,39 +544,6 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the Diffusers LoRA format.

    Two signals are combined: the FLUX LoRA format detector reporting Diffusers, and the presence of a
    ``pytorch_lora_weights`` file. Both give EXACT, either one gives MAYBE, neither gives NEVER.
    Overriding both ``type`` and ``format`` to this config gives OVERRIDE.
    """
    # If both type and format are overridden, skip the heuristic checks and return a perfect score
    if overrides.get("type") is ModelType.LoRA and overrides.get("format") is ModelFormat.Diffusers:
        return MatchCertainty.OVERRIDE

    # Diffusers LoRAs are always directories, never files
    if mod.path.is_file():
        return MatchCertainty.NEVER

    flux_says_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
    weights_present = any(
        (mod.path / f"pytorch_lora_weights.{sfx}").exists() for sfx in ("bin", "safetensors")
    )

    if flux_says_diffusers and weights_present:
        return MatchCertainty.EXACT
    if flux_says_diffusers or weights_present:
        return MatchCertainty.MAYBE
    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` field, as determined by ``cls.base_model``."""
    detected_base = cls.base_model(mod)
    return {"base": detected_base}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -916,56 +583,6 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the VAE checkpoint format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, MAYBE when the
    model is not a directory and has keys starting with one of ``KEY_PREFIXES``, and NEVER otherwise.
    """
    if overrides.get("type") is ModelType.VAE and overrides.get("format") is ModelFormat.Checkpoint:
        return MatchCertainty.OVERRIDE

    if mod.path.is_dir():
        return MatchCertainty.NEVER

    has_vae_keys = mod.has_keys_starting_with(cls.KEY_PREFIXES)
    return MatchCertainty.MAYBE if has_vae_keys else MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` and ``config_path`` fields for a VAE checkpoint.

    The config path is chosen from the base type; any base other than FLUX, SD1 or SDXL falls back to
    the v2 inference config.
    """
    base = cls.get_base_type(mod)

    if base is BaseModelType.Flux:
        # For flux, this is a key in invokeai.backend.flux.util.ae_params
        # Due to model type and format being the discriminator for model configs this
        # is used rather than attempting to support flux with separate model types and format
        # If changed in the future, please fix me
        config_path = "flux"
    elif base is BaseModelType.StableDiffusion1:
        config_path = "stable-diffusion/v1-inference.yaml"
    elif base is BaseModelType.StableDiffusionXL:
        config_path = "stable-diffusion/sd_xl_base.yaml"
    else:
        config_path = "stable-diffusion/v2-inference.yaml"

    return {"base": base, "config_path": config_path}
@classmethod
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
    """Guess the base model type from the model's file name.

    Patterns are tried in order, case-insensitively, and the first one found anywhere in the name wins.

    Raises:
        InvalidModelConfigException: If no pattern is found in the name.
    """
    # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
    name_heuristics = (
        (r"xl", BaseModelType.StableDiffusionXL),
        (r"sd2", BaseModelType.StableDiffusion2),
        (r"vae", BaseModelType.StableDiffusion1),
        (r"FLUX.1-schnell_ae", BaseModelType.Flux),
    )
    guessed = next(
        (basetype for regexp, basetype in name_heuristics if re.search(regexp, mod.path.name, re.IGNORECASE)),
        None,
    )
    if guessed is None:
        raise InvalidModelConfigException("Cannot determine base type")
    return guessed
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
@@ -1012,50 +629,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the Diffusers VAE format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, EXACT when the
    model's config names one of ``CLASS_NAMES``, and NEVER otherwise (including when the config cannot
    be read).
    """
    if overrides.get("type") is ModelType.VAE and overrides.get("format") is ModelFormat.Diffusers:
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    try:
        class_name = get_class_name_from_config(cls.get_config(mod))
        is_known_class = class_name in cls.CLASS_NAMES
    except Exception:
        # Best-effort probe: a missing or unreadable config means "not a match".
        is_known_class = False

    return MatchCertainty.EXACT if is_known_class else MatchCertainty.NEVER
@classmethod
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Load and return ``config.json`` from the model directory as parsed JSON."""
    with open(mod.path / "config.json", "r") as handle:
        parsed = json.load(handle)
    return parsed
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` field, as determined by ``cls.get_base_type``."""
    return {"base": cls.get_base_type(mod)}
@classmethod
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
    """Return SDXL when either the config or the name looks like SDXL, otherwise SD1.

    The config check runs first; the name check is only consulted when it fails.
    """
    looks_like_sdxl = cls._config_looks_like_sdxl(mod) or cls._name_looks_like_sdxl(mod)
    if looks_like_sdxl:
        return BaseModelType.StableDiffusionXL
    # We do not support diffusers VAEs for any other base model at this time... YOLO
    return BaseModelType.StableDiffusion1
@classmethod
def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
config = cls.get_config(mod)
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
# Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
@@ -1074,8 +648,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
return name
@classmethod
def get_base(cls, mod: ModelOnDisk) -> BaseModelType:
if cls._config_looks_like_sdxl(mod):
def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType:
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
return BaseModelType.StableDiffusionXL
@@ -1113,7 +687,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
base = fields.get("base") or cls.get_base(mod)
base = fields.get("base") or cls.get_base(mod, config)
return cls(**fields, base=base)
@@ -1231,32 +805,6 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the single-file textual inversion format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, MAYBE when the
    model is not a directory and ``file_looks_like_embedding`` accepts it, and NEVER otherwise.
    """
    wants_embedding = overrides.get("type") is ModelType.TextualInversion
    wants_file = overrides.get("format") is ModelFormat.EmbeddingFile
    if wants_embedding and wants_file:
        return MatchCertainty.OVERRIDE

    if mod.path.is_dir():
        return MatchCertainty.NEVER

    return MatchCertainty.MAYBE if cls.file_looks_like_embedding(mod) else MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` field, as determined by ``cls.get_base``.

    Raises:
        InvalidModelConfigException: If the base type cannot be determined. The underlying error is
            attached as the cause so the reason is not lost.
    """
    try:
        return {"base": cls.get_base(mod)}
    except Exception as err:
        raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") from err
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1290,34 +838,6 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the folder-based textual inversion format.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, MAYBE when the
    model is a directory containing at least one entry that ``file_looks_like_embedding`` accepts, and
    NEVER otherwise.
    """
    wants_embedding = overrides.get("type") is ModelType.TextualInversion
    wants_folder = overrides.get("format") is ModelFormat.EmbeddingFolder
    if wants_embedding and wants_folder:
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    # Stops at the first directory entry that looks like an embedding.
    if any(cls.file_looks_like_embedding(mod, entry) for entry in mod.path.iterdir()):
        return MatchCertainty.MAYBE
    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``base`` field, read from the folder's learned-embeddings file.

    Each candidate filename is tried in a fixed order; the first one for which ``cls.get_base``
    succeeds determines the result.

    Raises:
        InvalidModelConfigException: If no candidate yields a base type. The error from the last
            candidate is attached as the cause.
    """
    last_error = None
    # A tuple (not a set) keeps the order deterministic, and the per-file try lets the second
    # candidate be attempted when the first one fails.
    for filename in ("learned_embeds.bin", "learned_embeds.safetensors"):
        try:
            return {"base": cls.get_base(mod, mod.path / filename)}
        except Exception as err:
            last_error = err

    raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") from last_error
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1367,14 +887,6 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
# @classmethod
# def matches(cls, mod: ModelOnDisk) -> bool:
# pass
# @classmethod
# def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
# pass
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
@@ -1425,44 +937,26 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: Literal[BaseModelType.Any] = BaseModelType.Any
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the ``variant`` field, as determined by ``cls.get_clip_variant_type``.

    Raises:
        InvalidModelConfigException: If no CLIP variant could be determined.
    """
    if (variant := cls.get_clip_variant_type(mod)) is None:
        raise InvalidModelConfigException("Unable to determine CLIP variant type")
    return {"variant": variant}
CLASS_NAMES: ClassVar = {
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
}
@classmethod
def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None:
def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
try:
with open(mod.path / "config.json") as file:
config = json.load(file)
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
@classmethod
def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool:
    """Report whether the model's ``config.json`` declares a CLIP architecture.

    Only the first entry of the ``architectures`` list is considered. Any failure (missing file,
    invalid JSON, missing or empty ``architectures``) yields False.
    """
    clip_architectures = (
        "CLIPModel",
        "CLIPTextModel",
        "CLIPTextModelWithProjection",
    )
    try:
        with open(mod.path / "config.json", "r") as handle:
            declared = json.load(handle).get("architectures")
        return declared[0] in clip_architectures
    except Exception:
        return False
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
@@ -1473,26 +967,6 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the CLIP-G embedding config.

    Returns OVERRIDE when the ``type``, ``format`` and ``variant`` overrides all select this config,
    EXACT when the model is a CLIP text encoder whose variant is G, and NEVER otherwise.
    """
    required_overrides = {
        "type": ModelType.CLIPEmbed,
        "format": ModelFormat.Diffusers,
        "variant": ClipVariantType.G,
    }
    if all(overrides.get(name) is expected for name, expected in required_overrides.items()):
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    looks_like_clip = cls.is_clip_text_encoder(mod)
    variant = cls.get_clip_variant_type(mod)
    if looks_like_clip and variant is ClipVariantType.G:
        return MatchCertainty.EXACT
    return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1518,10 +992,22 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
if not is_clip_embed or clip_variant is not ClipVariantType.G:
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
clip_variant = cls.get_clip_variant_type(config)
if clip_variant is not ClipVariantType.G:
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
return cls(**fields)
@@ -1536,26 +1022,6 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the CLIP-L embedding config.

    Returns OVERRIDE when the ``type``, ``format`` and ``variant`` overrides all select this config,
    EXACT when the model is a CLIP text encoder whose variant is L, and NEVER otherwise.
    """
    required_overrides = {
        "type": ModelType.CLIPEmbed,
        "format": ModelFormat.Diffusers,
        "variant": ClipVariantType.L,
    }
    if all(overrides.get(name) is expected for name, expected in required_overrides.items()):
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    looks_like_clip = cls.is_clip_text_encoder(mod)
    variant = cls.get_clip_variant_type(mod)
    if looks_like_clip and variant is ClipVariantType.L:
        return MatchCertainty.EXACT
    return MatchCertainty.NEVER
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1581,10 +1047,22 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
if not is_clip_embed or clip_variant is not ClipVariantType.L:
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
clip_variant = cls.get_clip_variant_type(config)
if clip_variant is not ClipVariantType.L:
raise NotAMatch(cls, "model does not match CLIP-L heuristics")
return cls(**fields)
@@ -1607,41 +1085,10 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
base: Literal[BaseModelType.Any] = BaseModelType.Any
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches a Spandrel image-to-image model.

    Returns EXACT when the file loads as a Spandrel model and NEVER otherwise. Unexpected load errors
    are logged as warnings and treated as NEVER.
    """
    if not mod.path.is_file():
        return MatchCertainty.NEVER

    try:
        # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
        # explored to avoid this:
        # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
        #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
        #    supported on meta tensors.
        # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
        #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
        #    maintain it, and the risk of false positive detections is higher.
        SpandrelImageToImageModel.load_from_file(mod.path)
    except spandrel.UnsupportedModelError:
        # Expected for non-Spandrel files; not worth logging.
        return MatchCertainty.NEVER
    except Exception as e:
        logger.warning(
            f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}"
        )
        return MatchCertainty.NEVER
    else:
        return MatchCertainty.EXACT
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return an empty dict: no fields are derived from the model on disk here."""
    return dict()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1704,37 +1151,6 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
base: Literal[BaseModelType.Any] = BaseModelType.Any
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Score how well the model matches the LLaVA OneVision config.

    Returns OVERRIDE when both the ``type`` and ``format`` overrides select this config, EXACT when
    ``config.json`` lists ``LlavaOnevisionForConditionalGeneration`` as its first architecture, and
    NEVER otherwise. Only a missing ``config.json`` is handled here; other read errors propagate.
    """
    if overrides.get("type") is ModelType.LlavaOnevision and overrides.get("format") is ModelFormat.Diffusers:
        return MatchCertainty.OVERRIDE

    if mod.path.is_file():
        return MatchCertainty.NEVER

    try:
        with open(mod.path / "config.json", "r") as handle:
            config = json.load(handle)
    except FileNotFoundError:
        return MatchCertainty.NEVER

    declared = config.get("architectures")
    is_llava = bool(declared) and declared[0] == "LlavaOnevisionForConditionalGeneration"
    return MatchCertainty.EXACT if is_llava else MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Return the fixed ``base`` and ``variant`` fields; nothing is read from the model."""
    return dict(base=BaseModelType.Any, variant=ModelVariantType.Normal)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
@@ -1776,33 +1192,16 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Always return NEVER: API models are not stored on disk, so they cannot be matched."""
    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Unsupported for API models.

    Raises:
        NotImplementedError: Always.
    """
    message = "API models are not parsed from disk."
    raise NotImplementedError(message)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Unsupported for API models.

    Raises:
        NotAMatch: Always.
    """
    reason = "API models cannot be built from disk"
    raise NotAMatch(cls, reason)
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
    """Model config for API-based video models."""

    format: Literal[ModelFormat.Api] = ModelFormat.Api

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Always return NEVER: API models are not stored on disk, so they cannot be matched."""
        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Unsupported for API models; always raises NotImplementedError."""
        message = "API models are not parsed from disk."
        raise NotImplementedError(message)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Unsupported for API models; always raises NotAMatch."""
        reason = "API models cannot be built from disk"
        raise NotAMatch(cls, reason)
@@ -1885,15 +1284,6 @@ AnyModelConfig = Annotated[
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
@dataclass
class ModelClassificationResultSuccess:
model: AnyModelConfig
@dataclass
class ModelClassificationResultFailure:
error: Exception
ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure
class ModelConfigFactory:
@staticmethod