mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(mm): tag generation & scattered probe fixes
This commit is contained in:
@@ -146,19 +146,19 @@ class FieldValidator:
|
||||
return FieldValidator.get_validator(model, field_name).validate_python(value)
|
||||
|
||||
|
||||
def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has all of the specified keys."""
|
||||
def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has any of the specified keys."""
|
||||
_keys = {keys} if isinstance(keys, str) else keys
|
||||
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
|
||||
return any(key in state_dict for key in _keys)
|
||||
|
||||
|
||||
def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
|
||||
def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
|
||||
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
|
||||
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
|
||||
|
||||
|
||||
def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
|
||||
def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
|
||||
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
|
||||
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
|
||||
@@ -408,14 +408,63 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
"""Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
|
||||
pydantic uses the tag to determine which subclass to instantiate.
|
||||
|
||||
The tag is a dot-separated string of the type, format, base and variant (if applicable).
|
||||
"""
|
||||
tag_strings: list[str] = []
|
||||
for name in ("type", "base", "format", "variant"):
|
||||
for name in ("type", "format", "base", "variant"):
|
||||
if field := cls.model_fields.get(name):
|
||||
if field.default is not PydanticUndefined:
|
||||
# We assume each of these fields has an Enum for its default
|
||||
tag_strings.append(str(field.default.value))
|
||||
# We expect each of these fields has an Enum for its default; we want the value of the enum.
|
||||
tag_strings.append(field.default.value)
|
||||
return Tag(".".join(tag_strings))
|
||||
|
||||
@staticmethod
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
if isinstance(v, ModelConfigBase):
|
||||
# We have an instance of a ModelConfigBase subclass - use its tag directly.
|
||||
return v.get_tag().tag
|
||||
if isinstance(v, dict):
|
||||
# We have a dict - compute the tag from its fields.
|
||||
tag_strings: list[str] = []
|
||||
if type_ := v.get("type"):
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
tag_strings.append(type_)
|
||||
|
||||
if format_ := v.get("format"):
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
tag_strings.append(format_)
|
||||
|
||||
if base_ := v.get("base"):
|
||||
if isinstance(base_, Enum):
|
||||
base_ = base_.value
|
||||
tag_strings.append(base_)
|
||||
|
||||
# Special case: CLIP Embed models also need the variant to distinguish them.
|
||||
if (
|
||||
type_ == ModelType.CLIPEmbed.value
|
||||
and format_ == ModelFormat.Diffusers.value
|
||||
and base_ == BaseModelType.Any.value
|
||||
):
|
||||
if variant_value := v.get("variant"):
|
||||
if isinstance(variant_value, Enum):
|
||||
variant_value = variant_value.value
|
||||
tag_strings.append(variant_value)
|
||||
else:
|
||||
raise ValueError("CLIP Embed model config dict must include a 'variant' field")
|
||||
|
||||
return ".".join(tag_strings)
|
||||
else:
|
||||
raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
"""Given the model on disk and any overrides, return an instance of this config class.
|
||||
@@ -536,7 +585,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
|
||||
has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB")
|
||||
has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB")
|
||||
if not has_scb_key_suffix:
|
||||
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
|
||||
|
||||
@@ -578,7 +627,7 @@ class LoRA_OMI_Config_Base(LoRAConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -643,7 +692,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -657,7 +706,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
|
||||
|
||||
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
||||
# Some main models have these keys, likely due to the creator merging in a LoRA.
|
||||
has_key_with_lora_prefix = has_keys_starting_with(
|
||||
has_key_with_lora_prefix = has_any_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"lora_te_",
|
||||
@@ -668,7 +717,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
|
||||
},
|
||||
)
|
||||
|
||||
has_key_with_lora_suffix = has_keys_ending_with(
|
||||
has_key_with_lora_suffix = has_any_keys_ending_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"to_k_lora.up.weight",
|
||||
@@ -769,7 +818,7 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -850,14 +899,14 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
|
||||
if not has_keys_starting_with(
|
||||
if not has_any_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"encoder.conv_in",
|
||||
@@ -920,7 +969,7 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -992,7 +1041,7 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -1056,14 +1105,14 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
|
||||
if has_keys_starting_with(
|
||||
if has_any_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"controlnet",
|
||||
@@ -1134,7 +1183,7 @@ class TI_Config_Base(ABC, BaseModel):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod, path)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -1330,7 +1379,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -1355,7 +1404,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
|
||||
base = cls.model_fields["base"].default.value
|
||||
base = cls.model_fields["base"].default
|
||||
|
||||
if base is BaseModelType.StableDiffusion2:
|
||||
state_dict = mod.load_state_dict()
|
||||
@@ -1372,7 +1421,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
|
||||
base = cls.model_fields["base"].default.value
|
||||
base = cls.model_fields["base"].default
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
key_name = "model.diffusion_model.input_blocks.0.0.weight"
|
||||
@@ -1490,7 +1539,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon
|
||||
|
||||
@classmethod
|
||||
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
|
||||
if not has_keys_exact(
|
||||
if not has_any_keys(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
@@ -1675,7 +1724,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -1719,7 +1768,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
|
||||
base = cls.model_fields["base"].default.value
|
||||
base = cls.model_fields["base"].default
|
||||
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
|
||||
in_channels = unet_config.get("in_channels")
|
||||
|
||||
@@ -1882,7 +1931,7 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -1951,14 +2000,14 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
|
||||
if not has_keys_starting_with(
|
||||
if not has_any_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"image_proj.",
|
||||
@@ -2035,7 +2084,10 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
mod.path / "config.json",
|
||||
mod.path / "text_encoder" / "config.json",
|
||||
},
|
||||
{
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
@@ -2050,8 +2102,14 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
|
||||
@classmethod
|
||||
def _validate_variant(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model variant does not match this config class."""
|
||||
expected_variant = cls.model_fields["variant"].default.value
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
expected_variant = cls.model_fields["variant"].default
|
||||
config = _get_config_or_raise(
|
||||
cls,
|
||||
{
|
||||
mod.path / "config.json",
|
||||
mod.path / "text_encoder" / "config.json",
|
||||
},
|
||||
)
|
||||
recognized_variant = _get_clip_variant_type_from_config(config)
|
||||
|
||||
if recognized_variant is None:
|
||||
@@ -2120,7 +2178,7 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
expected_base = cls.model_fields["base"].default
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
@@ -2294,24 +2352,6 @@ class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelC
|
||||
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
if isinstance(v, ModelConfigBase):
|
||||
return v.get_tag().tag
|
||||
|
||||
tag_strings: list[str] = []
|
||||
for name in ("type", "base", "format", "variant"):
|
||||
field_value = v.get(name)
|
||||
if isinstance(field_value, Enum):
|
||||
field_value = field_value.value
|
||||
if field_value is not None:
|
||||
tag_strings.append(field_value)
|
||||
return ".".join(tag_strings)
|
||||
|
||||
|
||||
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
|
||||
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
|
||||
AnyModelConfig = Annotated[
|
||||
@@ -2407,7 +2447,7 @@ AnyModelConfig = Annotated[
|
||||
# Unknown model (fallback)
|
||||
Annotated[Unknown_Config, Unknown_Config.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
Discriminator(ModelConfigBase.get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
|
||||
@@ -2513,19 +2553,34 @@ class ModelConfigFactory:
|
||||
matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
|
||||
|
||||
if not matches and app_config.allow_unknown_models:
|
||||
logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig")
|
||||
logger.debug(f"Model matching results: {results}")
|
||||
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
|
||||
return Unknown_Config(**fields)
|
||||
|
||||
instance = next(iter(matches))
|
||||
if len(matches) > 1:
|
||||
# TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the
|
||||
# matches, probably on a case-by-case basis.
|
||||
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
|
||||
#
|
||||
# One known case is certain SD main (pipeline) models can look like a LoRA. This could happen if the model
|
||||
# contains merged in LoRA weights.
|
||||
# Known cases:
|
||||
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
|
||||
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
|
||||
# a config.json file. Prefer the main model.
|
||||
|
||||
# Sort the matching according to known special cases.
|
||||
def sort_key(m: AnyModelConfig) -> int:
|
||||
match m.type:
|
||||
case ModelType.Main:
|
||||
return 0
|
||||
case ModelType.LoRA:
|
||||
return 1
|
||||
case ModelType.CLIPEmbed:
|
||||
return 2
|
||||
case _:
|
||||
return 3
|
||||
|
||||
matches.sort(key=sort_key)
|
||||
logger.warning(
|
||||
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(instance).__name__}."
|
||||
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
|
||||
)
|
||||
|
||||
instance = matches[0]
|
||||
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
|
||||
return instance
|
||||
|
||||
@@ -16,12 +16,12 @@ from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
Main_Checkpoint_SD1_Config,
|
||||
Main_Diffusers_SD1_Config,
|
||||
Main_Checkpoint_SD2_Config,
|
||||
Main_Diffusers_SD2_Config,
|
||||
Main_Checkpoint_SDXL_Config,
|
||||
Main_Diffusers_SDXL_Config,
|
||||
Main_Checkpoint_SDXLRefiner_Config,
|
||||
Main_Diffusers_SD1_Config,
|
||||
Main_Diffusers_SD2_Config,
|
||||
Main_Diffusers_SDXL_Config,
|
||||
Main_Diffusers_SDXLRefiner_Config,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
|
||||
Reference in New Issue
Block a user