From 165f57286aea9bfeeee3cc51913baa0792e02e38 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 1 Oct 2025 19:47:57 +1000 Subject: [PATCH] fix(mm): tag generation & scattered probe fixes --- invokeai/backend/model_manager/config.py | 177 ++++++++++++------ .../load/model_loaders/stable_diffusion.py | 6 +- 2 files changed, 119 insertions(+), 64 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 442f9731a1..c82b0673bd 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -146,19 +146,19 @@ class FieldValidator: return FieldValidator.get_validator(model, field_name).validate_python(value) -def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: - """Returns true if the state dict has all of the specified keys.""" +def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: + """Returns true if the state dict has any of the specified keys.""" _keys = {keys} if isinstance(keys, str) else keys - return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)}) + return any(key in state_dict for key in _keys) -def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: +def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: """Returns true if the state dict has any keys starting with any of the specified prefixes.""" _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)) -def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: +def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: """Returns true if the state dict has any keys ending with any of the specified suffixes.""" _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)) @@ -408,14 +408,63 @@ class ModelConfigBase(ABC, BaseModel): @classmethod def get_tag(cls) -> Tag: + """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized, + pydantic uses the tag to determine which subclass to instantiate. + + The tag is a dot-separated string of the type, format, base and variant (if applicable). + """ tag_strings: list[str] = [] - for name in ("type", "base", "format", "variant"): + for name in ("type", "format", "base", "variant"): if field := cls.model_fields.get(name): if field.default is not PydanticUndefined: - # We assume each of these fields has an Enum for its default - tag_strings.append(str(field.default.value)) + # We expect each of these fields has an Enum for its default; we want the value of the enum. + tag_strings.append(field.default.value) return Tag(".".join(tag_strings)) + @staticmethod + def get_model_discriminator_value(v: Any) -> str: + """ + Computes the discriminator value for a model config. + https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator + """ + if isinstance(v, ModelConfigBase): + # We have an instance of a ModelConfigBase subclass - use its tag directly. + return v.get_tag().tag + if isinstance(v, dict): + # We have a dict - compute the tag from its fields. + tag_strings: list[str] = [] + if type_ := v.get("type"): + if isinstance(type_, Enum): + type_ = type_.value + tag_strings.append(type_) + + if format_ := v.get("format"): + if isinstance(format_, Enum): + format_ = format_.value + tag_strings.append(format_) + + if base_ := v.get("base"): + if isinstance(base_, Enum): + base_ = base_.value + tag_strings.append(base_) + + # Special case: CLIP Embed models also need the variant to distinguish them. + if ( + type_ == ModelType.CLIPEmbed.value + and format_ == ModelFormat.Diffusers.value + and base_ == BaseModelType.Any.value + ): + if variant_value := v.get("variant"): + if isinstance(variant_value, Enum): + variant_value = variant_value.value + tag_strings.append(variant_value) + else: + raise ValueError("CLIP Embed model config dict must include a 'variant' field") + + return ".".join(tag_strings) + else: + raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance") + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: """Given the model on disk and any overrides, return an instance of this config class. @@ -536,7 +585,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase): @classmethod def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB") + has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB") if not has_scb_key_suffix: raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") @@ -578,7 +627,7 @@ class LoRA_OMI_Config_Base(LoRAConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -643,7 +692,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -657,7 +706,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase): # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. # Some main models have these keys, likely due to the creator merging in a LoRA. - has_key_with_lora_prefix = has_keys_starting_with( + has_key_with_lora_prefix = has_any_keys_starting_with( mod.load_state_dict(), { "lora_te_", @@ -668,7 +717,7 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase): }, ) - has_key_with_lora_suffix = has_keys_ending_with( + has_key_with_lora_suffix = has_any_keys_ending_with( mod.load_state_dict(), { "to_k_lora.up.weight", @@ -769,7 +818,7 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -850,14 +899,14 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: - if not has_keys_starting_with( + if not has_any_keys_starting_with( mod.load_state_dict(), { "encoder.conv_in", @@ -920,7 +969,7 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -992,7 +1041,7 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -1056,14 +1105,14 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: - if has_keys_starting_with( + if has_any_keys_starting_with( mod.load_state_dict(), { "controlnet", @@ -1134,7 +1183,7 @@ class TI_Config_Base(ABC, BaseModel): @classmethod def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod, path) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -1330,7 +1379,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -1355,7 +1404,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): @classmethod def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: - base = cls.model_fields["base"].default.value + base = cls.model_fields["base"].default if base is BaseModelType.StableDiffusion2: state_dict = mod.load_state_dict() @@ -1372,7 +1421,7 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): @classmethod def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: - base = cls.model_fields["base"].default.value + base = cls.model_fields["base"].default state_dict = mod.load_state_dict() key_name = "model.diffusion_model.input_blocks.0.0.weight" @@ -1490,7 +1539,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon @classmethod def _validate_is_flux(cls, mod: ModelOnDisk) -> None: - if not has_keys_exact( + if not has_any_keys( mod.load_state_dict(), { "double_blocks.0.img_attn.norm.key_norm.scale", @@ -1675,7 +1724,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -1719,7 +1768,7 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase): @classmethod def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: - base = cls.model_fields["base"].default.value + base = cls.model_fields["base"].default unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") in_channels = unet_config.get("in_channels") @@ -1882,7 +1931,7 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -1951,14 +2000,14 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase): @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: - if not has_keys_starting_with( + if not has_any_keys_starting_with( mod.load_state_dict(), { "image_proj.", @@ -2035,7 +2084,10 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase): _validate_class_name( cls, - common_config_paths(mod.path), + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, { "CLIPModel", "CLIPTextModel", @@ -2050,8 +2102,14 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase): @classmethod def _validate_variant(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model variant does not match this config class.""" - expected_variant = cls.model_fields["variant"].default.value - config = _get_config_or_raise(cls, common_config_paths(mod.path)) + expected_variant = cls.model_fields["variant"].default + config = _get_config_or_raise( + cls, + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + ) recognized_variant = _get_clip_variant_type_from_config(config) if recognized_variant is None: @@ -2120,7 +2178,7 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig @classmethod def _validate_base(cls, mod: ModelOnDisk) -> None: """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default.value + expected_base = cls.model_fields["base"].default recognized_base = cls._get_base_or_raise(mod) if expected_base is not recognized_base: raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @@ -2294,24 +2352,6 @@ class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelC base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) -def get_model_discriminator_value(v: Any) -> str: - """ - Computes the discriminator value for a model config. - https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator - """ - if isinstance(v, ModelConfigBase): - return v.get_tag().tag - - tag_strings: list[str] = [] - for name in ("type", "base", "format", "variant"): - field_value = v.get(name) - if isinstance(field_value, Enum): - field_value = field_value.value - if field_value is not None: - tag_strings.append(field_value) - return ".".join(tag_strings) - - # The types are listed explicitly because IDEs/LSPs can't identify the correct types # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes AnyModelConfig = Annotated[ @@ -2407,7 +2447,7 @@ AnyModelConfig = Annotated[ # Unknown model (fallback) Annotated[Unknown_Config, Unknown_Config.get_tag()], ], - Discriminator(get_model_discriminator_value), + Discriminator(ModelConfigBase.get_model_discriminator_value), ] AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) @@ -2513,19 +2553,34 @@ class ModelConfigFactory: matches = [r for r in results.values() if isinstance(r, ModelConfigBase)] if not matches and app_config.allow_unknown_models: - logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig") - logger.debug(f"Model matching results: {results}") + logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config") return Unknown_Config(**fields) - instance = next(iter(matches)) if len(matches) > 1: - # TODO(psyche): When we get multiple matches, at most only 1 will be correct. We should disambiguate the - # matches, probably on a case-by-case basis. + # We have multiple matches, in which case at most 1 is correct. We need to pick one. # - # One known case is certain SD main (pipeline) models can look like a LoRA. This could happen if the model - # contains merged in LoRA weights. + # Known cases: + # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model. + # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with + # a config.json file. Prefer the main model. + + # Sort the matching according to known special cases. + def sort_key(m: AnyModelConfig) -> int: + match m.type: + case ModelType.Main: + return 0 + case ModelType.LoRA: + return 1 + case ModelType.CLIPEmbed: + return 2 + case _: + return 3 + + matches.sort(key=sort_key) logger.warning( - f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(instance).__name__}." + f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}." ) + + instance = matches[0] logger.info(f"Model {mod.name} classified as {type(instance).__name__}") return instance diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 5b7791a4ba..c8c751134c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -16,12 +16,12 @@ from invokeai.backend.model_manager.config import ( CheckpointConfigBase, DiffusersConfigBase, Main_Checkpoint_SD1_Config, - Main_Diffusers_SD1_Config, Main_Checkpoint_SD2_Config, - Main_Diffusers_SD2_Config, Main_Checkpoint_SDXL_Config, - Main_Diffusers_SDXL_Config, Main_Checkpoint_SDXLRefiner_Config, + Main_Diffusers_SD1_Config, + Main_Diffusers_SD2_Config, + Main_Diffusers_SDXL_Config, Main_Diffusers_SDXLRefiner_Config, ) from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key