fix(mm): tag generation & scattered probe fixes

This commit is contained in:
psychedelicious
2025-10-01 19:47:57 +10:00
parent 1e1c8b988b
commit 165f57286a
2 changed files with 119 additions and 64 deletions

View File

@@ -146,19 +146,19 @@ class FieldValidator:
return FieldValidator.get_validator(model, field_name).validate_python(value)
def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
"""Returns true if the state dict has all of the specified keys."""
def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
"""Returns true if the state dict has any of the specified keys."""
_keys = {keys} if isinstance(keys, str) else keys
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
return any(key in state_dict for key in _keys)
def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
@@ -408,14 +408,63 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
def get_tag(cls) -> Tag:
"""Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
pydantic uses the tag to determine which subclass to instantiate.
The tag is a dot-separated string of the type, format, base and variant (if applicable).
"""
tag_strings: list[str] = []
for name in ("type", "base", "format", "variant"):
for name in ("type", "format", "base", "variant"):
if field := cls.model_fields.get(name):
if field.default is not PydanticUndefined:
# We assume each of these fields has an Enum for its default
tag_strings.append(str(field.default.value))
# We expect each of these fields has an Enum for its default; we want the value of the enum.
tag_strings.append(field.default.value)
return Tag(".".join(tag_strings))
@staticmethod
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value (a dot-separated tag string) for a model config.

This is the callable passed to pydantic's `Discriminator` for the model config union. See:
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator

Args:
v: Either a `ModelConfigBase` instance or a dict of model config fields.

Returns:
For an instance, the tag produced by its class's `get_tag()`. For a dict, the values of the "type", "format"
and "base" entries joined with "." in that order; Enum members are replaced by their `.value`. For the
CLIP Embed / Diffusers / Any combination, the "variant" value is appended as well.

Raises:
ValueError: If the dict describes a CLIP Embed / Diffusers / Any config but has no truthy "variant" entry.
TypeError: If `v` is neither a dict nor a `ModelConfigBase` instance.
"""
if isinstance(v, ModelConfigBase):
# We have an instance of a ModelConfigBase subclass - use its tag directly.
return v.get_tag().tag
if isinstance(v, dict):
# We have a dict - compute the tag from its fields.
# Fields are handled in type, format, base order - the same order get_tag() uses - so that a tag computed
# from a dict lines up with the tag computed from the corresponding config class.
tag_strings: list[str] = []
# Each lookup uses a truthiness check, so a missing, None or empty entry is treated as absent.
if type_ := v.get("type"):
if isinstance(type_, Enum):
type_ = type_.value
tag_strings.append(type_)
if format_ := v.get("format"):
if isinstance(format_, Enum):
format_ = format_.value
tag_strings.append(format_)
if base_ := v.get("base"):
if isinstance(base_, Enum):
base_ = base_.value
tag_strings.append(base_)
# Special case: CLIP Embed models also need the variant to distinguish them.
# The comparison is against the enum *values* because type_/format_/base_ were unwrapped above.
if (
type_ == ModelType.CLIPEmbed.value
and format_ == ModelFormat.Diffusers.value
and base_ == BaseModelType.Any.value
):
if variant_value := v.get("variant"):
if isinstance(variant_value, Enum):
variant_value = variant_value.value
tag_strings.append(variant_value)
else:
# Without a variant the tag would be ambiguous between the CLIP Embed config classes.
raise ValueError("CLIP Embed model config dict must include a 'variant' field")
return ".".join(tag_strings)
else:
raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
"""Given the model on disk and any overrides, return an instance of this config class.
@@ -536,7 +585,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB")
has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB")
if not has_scb_key_suffix:
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
@@ -578,7 +627,7 @@ class LoRA_OMI_Config_Base(LoRAConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -643,7 +692,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -657,7 +706,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = has_keys_starting_with(
has_key_with_lora_prefix = has_any_keys_starting_with(
mod.load_state_dict(),
{
"lora_te_",
@@ -668,7 +717,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
},
)
has_key_with_lora_suffix = has_keys_ending_with(
has_key_with_lora_suffix = has_any_keys_ending_with(
mod.load_state_dict(),
{
"to_k_lora.up.weight",
@@ -769,7 +818,7 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -850,14 +899,14 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
if not has_keys_starting_with(
if not has_any_keys_starting_with(
mod.load_state_dict(),
{
"encoder.conv_in",
@@ -920,7 +969,7 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -992,7 +1041,7 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -1056,14 +1105,14 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
if has_keys_starting_with(
if has_any_keys_starting_with(
mod.load_state_dict(),
{
"controlnet",
@@ -1134,7 +1183,7 @@ class TI_Config_Base(ABC, BaseModel):
@classmethod
def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod, path)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -1330,7 +1379,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -1355,7 +1404,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
base = cls.model_fields["base"].default.value
base = cls.model_fields["base"].default
if base is BaseModelType.StableDiffusion2:
state_dict = mod.load_state_dict()
@@ -1372,7 +1421,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default.value
base = cls.model_fields["base"].default
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.0.0.weight"
@@ -1490,7 +1539,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
if not has_keys_exact(
if not has_any_keys(
mod.load_state_dict(),
{
"double_blocks.0.img_attn.norm.key_norm.scale",
@@ -1675,7 +1724,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -1719,7 +1768,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default.value
base = cls.model_fields["base"].default
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
@@ -1882,7 +1931,7 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -1951,14 +2000,14 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
if not has_keys_starting_with(
if not has_any_keys_starting_with(
mod.load_state_dict(),
{
"image_proj.",
@@ -2035,7 +2084,10 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
_validate_class_name(
cls,
common_config_paths(mod.path),
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
{
"CLIPModel",
"CLIPTextModel",
@@ -2050,8 +2102,14 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
@classmethod
def _validate_variant(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model variant does not match this config class."""
expected_variant = cls.model_fields["variant"].default.value
config = _get_config_or_raise(cls, common_config_paths(mod.path))
expected_variant = cls.model_fields["variant"].default
config = _get_config_or_raise(
cls,
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
)
recognized_variant = _get_clip_variant_type_from_config(config)
if recognized_variant is None:
@@ -2120,7 +2178,7 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@@ -2294,24 +2352,6 @@ class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelC
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
if isinstance(v, ModelConfigBase):
return v.get_tag().tag
tag_strings: list[str] = []
for name in ("type", "base", "format", "variant"):
field_value = v.get(name)
if isinstance(field_value, Enum):
field_value = field_value.value
if field_value is not None:
tag_strings.append(field_value)
return ".".join(tag_strings)
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
@@ -2407,7 +2447,7 @@ AnyModelConfig = Annotated[
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Discriminator(get_model_discriminator_value),
Discriminator(ModelConfigBase.get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
@@ -2513,19 +2553,34 @@ class ModelConfigFactory:
matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
if not matches and app_config.allow_unknown_models:
logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig")
logger.debug(f"Model matching results: {results}")
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
return Unknown_Config(**fields)
instance = next(iter(matches))
if len(matches) > 1:
# TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the
# matches, probably on a case-by-case basis.
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
#
# One known case is certain SD main (pipeline) models can look like a LoRA. This could happen if the model
# contains merged in LoRA weights.
# Known cases:
# - SD main models can look like a LoRA when LoRA weights have been merged into them. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
# Sort the matches so that the preferred config type for the known special cases comes first.
def sort_key(m: AnyModelConfig) -> int:
"""Returns the sort rank for a matched config by its type: Main, then LoRA, then CLIP Embed, then all others."""
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case _:
return 3
matches.sort(key=sort_key)
logger.warning(
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(instance).__name__}."
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
)
instance = matches[0]
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
return instance

View File

@@ -16,12 +16,12 @@ from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
Main_Checkpoint_SD1_Config,
Main_Diffusers_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Diffusers_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Diffusers_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key