From d185b85fb79f9d76e8c7c003b59cc56584258b2e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 Sep 2025 20:08:08 +1000 Subject: [PATCH] feat(mm): port ip adapter to new api --- .../flux/ip_adapter/state_dict_utils.py | 6 +- invokeai/backend/model_manager/config.py | 141 +++++++++++++++++- 2 files changed, 136 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/flux/ip_adapter/state_dict_utils.py b/invokeai/backend/flux/ip_adapter/state_dict_utils.py index 90f11ff642..24ac53550f 100644 --- a/invokeai/backend/flux/ip_adapter/state_dict_utils.py +++ b/invokeai/backend/flux/ip_adapter/state_dict_utils.py @@ -1,11 +1,11 @@ -from typing import Any, Dict +from typing import Any import torch from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams -def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: +def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool: """Is the state dict for an XLabs FLUX IP-Adapter model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. 
@@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: return False -def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams: +def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams: num_double_blocks = 0 context_dim = 0 hidden_dim = 0 diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index a1e043dc35..03fb8c66e1 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config from invokeai.app.util.misc import uuid_string +from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS @@ -149,7 +150,8 @@ def _validate_overrides( Args: config_class: The config class that is being tested. provided_overrides: The overrides provided by the user. - valid_overrides: The overrides that are valid for this config class. + valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a + callable that takes the provided value and returns True if it is valid. Returns: True if all provided overrides match the valid overrides, False if some valid overrides are missing. @@ -158,14 +160,21 @@ def _validate_overrides( NotAMatch if any override does not match the allowed value. 
""" is_perfect_match = True - for key, value in valid_overrides.items(): - if key not in provided_overrides: + for override_name, constraint in valid_overrides.items(): + if override_name not in provided_overrides: is_perfect_match = False continue - if provided_overrides[key] != value: + # Handle the typical case where the constraint is a specific value + if provided_overrides[override_name] != constraint: raise NotAMatch( config_class, - f"override {key}={provided_overrides[key]} does not match required value {key}={value}", + f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}", + ) + # Handle the less common case where the constraint is a callable + elif callable(constraint) and not constraint(provided_overrides[override_name]): + raise NotAMatch( + config_class, + f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable", ) return is_perfect_match @@ -521,7 +530,6 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - # LyCORIS LoRAs are always files, never directories _raise_if_not_file(cls, mod) if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): @@ -895,20 +903,137 @@ class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter -class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): +IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[ + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, +] +"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format.""" + +ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes) +"""Helper TypeAdapter for IP Adapter InvokeAI base types.""" + + +class 
IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): """Model config for IP Adapter diffusers format models.""" # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long # time. Need to go through the history to make sure I'm understanding this fully. image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI + base: IPAdapterInvokeAIConfigBaseTypes = Field(...) + + VALID_OVERRIDES: ClassVar = { + "type": ModelType.IPAdapter, + "format": ModelFormat.InvokeAI, + "base": ip_adapter_invoke_ai_base_type_adapter.validate_python, + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _raise_if_not_dir(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) + + weights_file = mod.path / "ip_adapter.bin" + if not weights_file.exists(): + raise NotAMatch(cls, "missing ip_adapter.bin weights file") + + image_encoder_metadata_file = mod.path / "image_encoder.txt" + if not image_encoder_metadata_file.exists(): + raise NotAMatch(cls, "missing image_encoder.txt metadata file") + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes: + state_dict = mod.load_state_dict() + + try: + cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") -class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): 
+IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[ + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + BaseModelType.Flux, +] +"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format.""" + +ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes]( + IPAdapterCheckpointConfigBaseTypes +) +"""Helper TypeAdapter for IP Adapter Checkpoint base types.""" + + +class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): """Model config for IP Adapter checkpoint format models.""" format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + VALID_OVERRIDES: ClassVar = { + "type": ModelType.IPAdapter, + "format": ModelFormat.Checkpoint, + "base": ip_adapter_checkpoint_base_type_adapter.validate_python, + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _raise_if_not_file(cls, mod) + + if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): + return cls(**fields) + + if not mod.has_keys_starting_with( + { + "image_proj.", + "ip_adapter.", + # XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.". 
+ "ip_adapter_proj_model.", + } + ): + raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_ip_adapter(state_dict): + return BaseModelType.Flux + + try: + cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") + class CLIPEmbedDiffusersConfig(DiffusersConfigBase): """Model config for Clip Embeddings."""