mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): remove unused methods in config.py
This commit is contained in:
@@ -21,7 +21,6 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
"""
|
||||
|
||||
# pyright: reportIncompatibleVariableOverride=false
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -40,7 +39,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import spandrel
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
@@ -88,7 +86,7 @@ class NotAMatch(Exception):
|
||||
reason: The reason why the model did not match.
|
||||
"""
|
||||
|
||||
def __init__(self, config_class: "Type[ModelConfigBase]", reason: str):
|
||||
def __init__(self, config_class: "Type[AnyModelConfig]", reason: str):
|
||||
super().__init__(f"{config_class.__name__} does not match: {reason}")
|
||||
|
||||
|
||||
@@ -104,6 +102,19 @@ def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def validate_overrides(
|
||||
config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any]
|
||||
) -> None:
|
||||
for key, value in allowed.items():
|
||||
if key not in overrides:
|
||||
continue
|
||||
if overrides[key] != value:
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {key}={overrides[key]} does not match required value {key}={value}",
|
||||
)
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
@@ -139,23 +150,6 @@ class ControlAdapterDefaultSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class MatchSpeed(int, Enum):
|
||||
"""Represents the estimated runtime speed of a config's 'matches' method."""
|
||||
|
||||
FAST = 0
|
||||
MED = 1
|
||||
SLOW = 2
|
||||
|
||||
|
||||
class MatchCertainty(int, Enum):
|
||||
"""Represents the certainty of a config's 'matches' method."""
|
||||
|
||||
NEVER = 0
|
||||
MAYBE = 1
|
||||
EXACT = 2
|
||||
OVERRIDE = 3
|
||||
|
||||
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
|
||||
@@ -213,7 +207,6 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
|
||||
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
||||
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -228,132 +221,20 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
concrete = {cls for cls in subclasses if not isabstract(cls)}
|
||||
return concrete
|
||||
|
||||
@staticmethod
|
||||
def classify(
|
||||
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
|
||||
) -> "AnyModelConfig":
|
||||
"""
|
||||
Returns the best matching ModelConfig instance from a model's file/folder path.
|
||||
Raises InvalidModelConfigException if no valid configuration is found.
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
if isinstance(mod, Path | str):
|
||||
mod = ModelOnDisk(Path(mod), hash_algo)
|
||||
|
||||
candidates = ModelConfigBase.USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
|
||||
|
||||
overrides = overrides or {}
|
||||
ModelConfigBase.cast_overrides(**overrides)
|
||||
|
||||
matches: dict[Type[ModelConfigBase], MatchCertainty] = {}
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
score = config_cls.matches(mod, **overrides)
|
||||
|
||||
# A score of 0 means "no match"
|
||||
if score is MatchCertainty.NEVER:
|
||||
continue
|
||||
|
||||
matches[config_cls] = score
|
||||
|
||||
if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE:
|
||||
# Perfect match - skip further checks
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
|
||||
continue
|
||||
|
||||
if matches:
|
||||
# Select the config class with the highest score
|
||||
sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value)
|
||||
# Check if there are multiple classes with the same top score
|
||||
top_score = sorted_by_score[-1][1]
|
||||
top_classes = [cls for cls, score in sorted_by_score if score is top_score]
|
||||
if len(top_classes) > 1:
|
||||
logger.warning(
|
||||
f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}."
|
||||
)
|
||||
config_cls = top_classes[0]
|
||||
# Finally, create the config instance
|
||||
logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}")
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
if app_config.allow_unknown_models:
|
||||
try:
|
||||
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
|
||||
except Exception:
|
||||
# Fall through to raising the exception below
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException("Unable to determine model type")
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
type = cls.model_fields["type"].default.value
|
||||
format = cls.model_fields["format"].default.value
|
||||
return Tag(f"{type}.{format}")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
"""Returns a dictionary with the fields needed to construct the model.
|
||||
Raises InvalidModelConfigException if the model is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
"""Performs a quick check to determine if the config matches the model.
|
||||
Returns a MatchCertainty score."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
"""Performs a quick check to determine if the config matches the model.
|
||||
Returns a MatchCertainty score."""
|
||||
"""Given the model on disk and any overrides, return an instance of this config class.
|
||||
|
||||
Implementations should raise NotAMatch if the model does not match this config class."""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def cast_overrides(**overrides):
|
||||
"""Casts user overrides from str to Enum"""
|
||||
if "type" in overrides:
|
||||
overrides["type"] = ModelType(overrides["type"])
|
||||
|
||||
if "format" in overrides:
|
||||
overrides["format"] = ModelFormat(overrides["format"])
|
||||
|
||||
if "base" in overrides:
|
||||
overrides["base"] = BaseModelType(overrides["base"])
|
||||
|
||||
if "source_type" in overrides:
|
||||
overrides["source_type"] = ModelSourceType(overrides["source_type"])
|
||||
|
||||
if "variant" in overrides:
|
||||
overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk_2(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(**overrides)
|
||||
fields.update(overrides)
|
||||
|
||||
fields["path"] = mod.path.as_posix()
|
||||
fields["source"] = fields.get("source") or fields["path"]
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["name"] = fields.get("name") or mod.name
|
||||
fields["hash"] = fields.get("hash") or mod.hash()
|
||||
fields["key"] = fields.get("key") or uuid_string()
|
||||
fields["description"] = fields.get("description")
|
||||
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
|
||||
fields["file_size"] = fields.get("file_size") or mod.size()
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class UnknownModelConfig(ModelConfigBase):
|
||||
base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown
|
||||
@@ -361,12 +242,8 @@ class UnknownModelConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {}
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
raise NotAMatch(cls, "unknown model config cannot match any model")
|
||||
|
||||
|
||||
class CheckpointConfigBase(ABC, BaseModel):
|
||||
@@ -441,16 +318,6 @@ class T5EncoderConfigBase(ABC, BaseModel):
|
||||
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
||||
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
path = mod.path / "text_encoder_2" / "config.json"
|
||||
with open(path, "r") as file:
|
||||
return json.load(file)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def load_json(path: Path) -> dict[str, Any]:
|
||||
with open(path, "r") as file:
|
||||
@@ -460,35 +327,6 @@ def load_json(path: Path) -> dict[str, Any]:
|
||||
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
|
||||
is_t5_format_override = overrides.get("format") is ModelFormat.T5Encoder
|
||||
|
||||
if is_t5_type_override and is_t5_format_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
model_dir = mod.path / "text_encoder_2"
|
||||
|
||||
if not model_dir.exists():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
try:
|
||||
config = cls.get_config(mod)
|
||||
|
||||
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
|
||||
is_t5_format = (model_dir / "model.safetensors.index.json").exists()
|
||||
|
||||
if is_t5_encoder_model and is_t5_format:
|
||||
return MatchCertainty.EXACT
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -532,44 +370,6 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
|
||||
is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
if is_t5_type_override and is_bnb_format_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
model_dir = mod.path / "text_encoder_2"
|
||||
|
||||
if not model_dir.exists():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
try:
|
||||
config = cls.get_config(mod)
|
||||
|
||||
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
|
||||
|
||||
# Heuristic: look for the quantization in the name
|
||||
files = model_dir.glob("*.safetensors")
|
||||
filename_looks_like_bnb = any(x for x in files if "llm_int8" in x.as_posix())
|
||||
|
||||
if is_t5_encoder_model and filename_looks_like_bnb:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
# Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
|
||||
has_scb_key = mod.has_keys_ending_with("SCB")
|
||||
|
||||
if is_t5_encoder_model and has_scb_key:
|
||||
return MatchCertainty.EXACT
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -651,7 +451,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
metadata = mod.metadata()
|
||||
architecture = metadata["modelspec.architecture"]
|
||||
|
||||
@@ -662,112 +462,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
else:
|
||||
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_omi_override = overrides.get("format") is ModelFormat.OMI
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks
|
||||
if is_lora_override and is_omi_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# OMI LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
metadata = mod.metadata()
|
||||
is_omi_lora_heuristic = (
|
||||
bool(metadata.get("modelspec.sai_model_spec"))
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
if is_omi_lora_heuristic:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
metadata = mod.metadata()
|
||||
architecture = metadata["modelspec.architecture"]
|
||||
|
||||
if architecture == stable_diffusion_xl_1_lora:
|
||||
base = BaseModelType.StableDiffusionXL
|
||||
elif architecture == flux_dev_1_lora:
|
||||
base = BaseModelType.Flux
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
||||
|
||||
return {"base": base}
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_omi_override = overrides.get("format") is ModelFormat.LyCORIS
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks and return a perfect score
|
||||
if is_lora_override and is_omi_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# LyCORIS LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
# Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
||||
# Some main models have these keys, likely due to the creator merging in a LoRA.
|
||||
|
||||
has_key_with_lora_prefix = key.startswith(
|
||||
(
|
||||
"lora_te_",
|
||||
"lora_unet_",
|
||||
"lora_te1_",
|
||||
"lora_te2_",
|
||||
"lora_transformer_",
|
||||
)
|
||||
)
|
||||
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
has_key_with_lora_suffix = key.endswith(
|
||||
(
|
||||
"to_k_lora.up.weight",
|
||||
"to_q_lora.down.weight",
|
||||
"lora_A.weight",
|
||||
"lora_B.weight",
|
||||
)
|
||||
)
|
||||
|
||||
if has_key_with_lora_prefix or has_key_with_lora_suffix:
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -844,39 +544,6 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks and return a perfect score
|
||||
if is_lora_override and is_diffusers_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# Diffusers LoRAs are always directories, never files
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
has_lora_weight_file = any(wf.exists() for wf in weight_files)
|
||||
|
||||
if is_flux_lora_diffusers and has_lora_weight_file:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
if is_flux_lora_diffusers or has_lora_weight_file:
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -916,56 +583,6 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_vae_override = overrides.get("type") is ModelType.VAE
|
||||
is_checkpoint_override = overrides.get("format") is ModelFormat.Checkpoint
|
||||
|
||||
if is_vae_override and is_checkpoint_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_dir():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
if mod.has_keys_starting_with(cls.KEY_PREFIXES):
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
base = cls.get_base_type(mod)
|
||||
config_path = (
|
||||
# For flux, this is a key in invokeai.backend.flux.util.ae_params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
"flux"
|
||||
if base is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
if base is BaseModelType.StableDiffusion1
|
||||
else "stable-diffusion/sd_xl_base.yaml"
|
||||
if base is BaseModelType.StableDiffusionXL
|
||||
else "stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
return {"base": base, "config_path": config_path}
|
||||
|
||||
@classmethod
|
||||
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
|
||||
for regexp, basetype in [
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
|
||||
]:
|
||||
if re.search(regexp, mod.path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
@classmethod
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
|
||||
@@ -1012,50 +629,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_vae_override = overrides.get("type") is ModelType.VAE
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
|
||||
if is_vae_override and is_diffusers_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
try:
|
||||
config = cls.get_config(mod)
|
||||
class_name = get_class_name_from_config(config)
|
||||
if class_name in cls.CLASS_NAMES:
|
||||
return MatchCertainty.EXACT
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
config_path = mod.path / "config.json"
|
||||
with open(config_path, "r") as file:
|
||||
return json.load(file)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
base = cls.get_base_type(mod)
|
||||
return {"base": base}
|
||||
|
||||
@classmethod
|
||||
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
if cls._config_looks_like_sdxl(mod):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif cls._name_looks_like_sdxl(mod):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
# We do not support diffusers VAEs for any other base model at this time... YOLO
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
@classmethod
|
||||
def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
|
||||
config = cls.get_config(mod)
|
||||
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
|
||||
# Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
@@ -1074,8 +648,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def get_base(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
if cls._config_looks_like_sdxl(mod):
|
||||
def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType:
|
||||
if cls._config_looks_like_sdxl(config):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif cls._name_looks_like_sdxl(mod):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
@@ -1113,7 +687,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
|
||||
base = fields.get("base") or cls.get_base(mod)
|
||||
base = fields.get("base") or cls.get_base(mod, config)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
|
||||
@@ -1231,32 +805,6 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_embedding_override = overrides.get("type") is ModelType.TextualInversion
|
||||
is_file_override = overrides.get("format") is ModelFormat.EmbeddingFile
|
||||
|
||||
if is_embedding_override and is_file_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_dir():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
if cls.file_looks_like_embedding(mod):
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
try:
|
||||
base = cls.get_base(mod)
|
||||
return {"base": base}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1290,34 +838,6 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_embedding_override = overrides.get("type") is ModelType.TextualInversion
|
||||
is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder
|
||||
|
||||
if is_embedding_override and is_folder_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
for p in mod.path.iterdir():
|
||||
if cls.file_looks_like_embedding(mod, p):
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
try:
|
||||
for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}:
|
||||
base = cls.get_base(mod, mod.path / filename)
|
||||
return {"base": base}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1367,14 +887,6 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
# @classmethod
|
||||
# def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
# pass
|
||||
|
||||
# @classmethod
|
||||
# def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
# pass
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
@@ -1425,44 +937,26 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
if clip_variant is None:
|
||||
raise InvalidModelConfigException("Unable to determine CLIP variant type")
|
||||
|
||||
return {"variant": clip_variant}
|
||||
CLASS_NAMES: ClassVar = {
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None:
|
||||
def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
|
||||
try:
|
||||
with open(mod.path / "config.json") as file:
|
||||
config = json.load(file)
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool:
|
||||
try:
|
||||
with open(mod.path / "config.json", "r") as file:
|
||||
config = json.load(file)
|
||||
architectures = config.get("architectures")
|
||||
return architectures[0] in (
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
@@ -1473,26 +967,6 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
has_clip_variant_override = overrides.get("variant") is ClipVariantType.G
|
||||
|
||||
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
|
||||
if is_clip_embed and clip_variant is ClipVariantType.G:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1518,10 +992,22 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
|
||||
if not is_clip_embed or clip_variant is not ClipVariantType.G:
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config") from e
|
||||
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.G:
|
||||
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
@@ -1536,26 +1022,6 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
has_clip_variant_override = overrides.get("variant") is ClipVariantType.L
|
||||
|
||||
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
|
||||
if is_clip_embed and clip_variant is ClipVariantType.L:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1581,10 +1047,22 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
|
||||
if not is_clip_embed or clip_variant is not ClipVariantType.L:
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config") from e
|
||||
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
raise NotAMatch(cls, "model does not match CLIP-L heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
@@ -1607,41 +1085,10 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
    """Model config for Spandrel Image to Image models."""

    # Matching must actually load the model from disk (see matches()), so it is flagged slow.
    _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW  # requires loading the model from disk

    # Spandrel models are single-file checkpoints usable with any base model.
    base: Literal[BaseModelType.Any] = BaseModelType.Any
    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
if not mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||
# supported on meta tensors.
|
||||
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(mod.path)
|
||||
return MatchCertainty.EXACT
|
||||
except spandrel.UnsupportedModelError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}"
|
||||
)
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1704,37 +1151,6 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
    # LLaVA OneVision is base-model-agnostic and only ships in the normal variant.
    base: Literal[BaseModelType.Any] = BaseModelType.Any
    variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_llava_override = overrides.get("type") is ModelType.LlavaOnevision
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
|
||||
if is_llava_override and is_diffusers_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
try:
|
||||
with open(config_path, "r") as file:
|
||||
config = json.load(file)
|
||||
except FileNotFoundError:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
architectures = config.get("architectures")
|
||||
if architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration":
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": BaseModelType.Any,
|
||||
"variant": ModelVariantType.Normal,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
@@ -1776,33 +1192,16 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
|
||||
|
||||
    # API-hosted model: there is no on-disk format.
    format: Literal[ModelFormat.Api] = ModelFormat.Api
|
||||
    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Always report no match: API models have no on-disk representation."""
        # API models are not stored on disk, so we can't match them.
        return MatchCertainty.NEVER
|
||||
    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Unsupported for API models; there is nothing on disk to parse."""
        raise NotImplementedError("API models are not parsed from disk.")
|
||||
    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Unsupported for API models; raises NotAMatch unconditionally."""
        raise NotAMatch(cls, "API models cannot be built from disk")
|
||||
|
||||
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
    """Model config for API-based video models.

    These models live behind a remote API rather than on local disk, so every
    disk-probing entry point deliberately reports "no match" or raises.
    """

    # API-hosted model: there is no on-disk format.
    format: Literal[ModelFormat.Api] = ModelFormat.Api

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Always report no match: API models have no on-disk representation."""
        # API models are not stored on disk, so we can't match them.
        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Unsupported for API models; there is nothing on disk to parse."""
        raise NotImplementedError("API models are not parsed from disk.")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
        """Unsupported for API models; raises NotAMatch unconditionally."""
        raise NotAMatch(cls, "API models cannot be built from disk")
|
||||
@@ -1885,15 +1284,6 @@ AnyModelConfig = Annotated[
|
||||
# Runtime validator for the AnyModelConfig discriminated union.
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
# Union of the per-model-type default-settings models.
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
|
||||
@dataclass
class ModelClassificationResultSuccess:
    """Successful classification of a model on disk."""

    # The fully-validated model config that matched.
    model: AnyModelConfig


@dataclass
class ModelClassificationResultFailure:
    """Failed classification of a model on disk."""

    # The exception raised while attempting classification.
    error: Exception


# Result of attempting to classify a model on disk: success carries the config,
# failure carries the exception.
ModelClassificationResult = ModelClassificationResultSuccess | ModelClassificationResultFailure
|
||||
|
||||
class ModelConfigFactory:
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user