mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port ip adapter to new api
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
|
||||
|
||||
|
||||
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
|
||||
def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool:
|
||||
"""Is the state dict for an XLabs FLUX IP-Adapter model?
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
@@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
|
||||
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams:
|
||||
num_double_blocks = 0
|
||||
context_dim = 0
|
||||
hidden_dim = 0
|
||||
|
||||
@@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
@@ -149,7 +150,8 @@ def _validate_overrides(
|
||||
Args:
|
||||
config_class: The config class that is being tested.
|
||||
provided_overrides: The overrides provided by the user.
|
||||
valid_overrides: The overrides that are valid for this config class.
|
||||
valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a
|
||||
callable that takes the provided value and returns True if it is valid.
|
||||
|
||||
Returns:
|
||||
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
|
||||
@@ -158,14 +160,21 @@ def _validate_overrides(
|
||||
NotAMatch if any override does not match the allowed value.
|
||||
"""
|
||||
is_perfect_match = True
|
||||
for key, value in valid_overrides.items():
|
||||
if key not in provided_overrides:
|
||||
for override_name, constraint in valid_overrides.items():
|
||||
if override_name not in provided_overrides:
|
||||
is_perfect_match = False
|
||||
continue
|
||||
if provided_overrides[key] != value:
|
||||
# Handle the typical case where the constraint is a specific value
|
||||
if provided_overrides[override_name] != constraint:
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
|
||||
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}",
|
||||
)
|
||||
# Handle the less common case where the constraint is a callable
|
||||
elif callable(constraint) and not constraint(provided_overrides[override_name]):
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable",
|
||||
)
|
||||
|
||||
return is_perfect_match
|
||||
@@ -521,7 +530,6 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
# LyCORIS LoRAs are always files, never directories
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
@@ -895,20 +903,137 @@ class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
|
||||
|
||||
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
]
|
||||
"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format."""
|
||||
|
||||
ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes)
|
||||
"""Helper TypeAdapter for IP Adapter InvokeAI base types."""
|
||||
|
||||
|
||||
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
"""Model config for IP Adapter diffusers format models."""
|
||||
|
||||
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
|
||||
# time. Need to go through the history to make sure I'm understanding this fully.
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
|
||||
base: IPAdapterInvokeAIConfigBaseTypes = Field(...)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.InvokeAI,
|
||||
"base": ip_adapter_invoke_ai_base_type_adapter.validate_python,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
weights_file = mod.path / "ip_adapter.bin"
|
||||
if not weights_file.exists():
|
||||
raise NotAMatch(cls, "missing ip_adapter.bin weights file")
|
||||
|
||||
image_encoder_metadata_file = mod.path / "image_encoder.txt"
|
||||
if not image_encoder_metadata_file.exists():
|
||||
raise NotAMatch(cls, "missing image_encoder.txt metadata file")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes:
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
try:
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
|
||||
|
||||
match cross_attention_dim:
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.Flux,
|
||||
]
|
||||
"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format."""
|
||||
|
||||
ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes](
|
||||
IPAdapterCheckpointConfigBaseTypes
|
||||
)
|
||||
"""Helper TypeAdapter for IP Adapter Checkpoint base types."""
|
||||
|
||||
|
||||
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
"""Model config for IP Adapter checkpoint format models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
"base": ip_adapter_checkpoint_base_type_adapter.validate_python,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if not mod.has_keys_starting_with(
|
||||
{
|
||||
"image_proj.",
|
||||
"ip_adapter.",
|
||||
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
|
||||
"ip_adapter_proj_model.",
|
||||
}
|
||||
):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes:
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
if is_state_dict_xlabs_ip_adapter(state_dict):
|
||||
return BaseModelType.Flux
|
||||
|
||||
try:
|
||||
cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1]
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
|
||||
|
||||
match cross_attention_dim:
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
Reference in New Issue
Block a user