feat(mm): port ip adapter to new api

This commit is contained in:
psychedelicious
2025-09-24 20:08:08 +10:00
parent ab2b1b2bde
commit b74e0f6ca4
2 changed files with 136 additions and 11 deletions

View File

@@ -1,11 +1,11 @@
from typing import Any, Dict
from typing import Any
import torch
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an XLabs FLUX IP-Adapter model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
@@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
return False
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams:
num_double_blocks = 0
context_dim = 0
hidden_dim = 0

View File

@@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
@@ -149,7 +150,8 @@ def _validate_overrides(
Args:
config_class: The config class that is being tested.
provided_overrides: The overrides provided by the user.
valid_overrides: The overrides that are valid for this config class.
valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a
callable that takes the provided value and returns True if it is valid.
Returns:
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
@@ -158,14 +160,21 @@ def _validate_overrides(
NotAMatch if any override does not match the allowed value.
"""
is_perfect_match = True
for key, value in valid_overrides.items():
if key not in provided_overrides:
for override_name, constraint in valid_overrides.items():
if override_name not in provided_overrides:
is_perfect_match = False
continue
if provided_overrides[key] != value:
# Handle the typical case where the constraint is a specific value
if provided_overrides[override_name] != constraint:
raise NotAMatch(
config_class,
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}",
)
# Handle the less common case where the constraint is a callable
elif callable(constraint) and not constraint(provided_overrides[override_name]):
raise NotAMatch(
config_class,
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable",
)
return is_perfect_match
@@ -521,7 +530,6 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
# LyCORIS LoRAs are always files, never directories
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
@@ -895,20 +903,137 @@ class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format."""
ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes)
"""Helper TypeAdapter for IP Adapter InvokeAI base types."""
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
base: IPAdapterInvokeAIConfigBaseTypes = Field(...)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.InvokeAI,
"base": ip_adapter_invoke_ai_base_type_adapter.validate_python,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
weights_file = mod.path / "ip_adapter.bin"
if not weights_file.exists():
raise NotAMatch(cls, "missing ip_adapter.bin weights file")
image_encoder_metadata_file = mod.path / "image_encoder.txt"
if not image_encoder_metadata_file.exists():
raise NotAMatch(cls, "missing image_encoder.txt metadata file")
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes:
state_dict = mod.load_state_dict()
try:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 1280:
return BaseModelType.StableDiffusionXL
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case _:
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format."""
ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes](
IPAdapterCheckpointConfigBaseTypes
)
"""Helper TypeAdapter for IP Adapter Checkpoint base types."""
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.Checkpoint,
"base": ip_adapter_checkpoint_base_type_adapter.validate_python,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if not mod.has_keys_starting_with(
{
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
}
):
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_ip_adapter(state_dict):
return BaseModelType.Flux
try:
cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 1280:
return BaseModelType.StableDiffusionXL
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case _:
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""