mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 11:35:02 -05:00
53 lines
2.1 KiB
Python
53 lines
2.1 KiB
Python
from typing import Any, Dict
|
|
|
|
import torch
|
|
|
|
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
|
|
|
|
|
|
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
|
|
"""Is the state dict for an XLabs FLUX IP-Adapter model?
|
|
|
|
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
|
"""
|
|
# If all of the expected keys are present, then this is very likely an XLabs IP-Adapter model.
|
|
expected_keys = {
|
|
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias",
|
|
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight",
|
|
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias",
|
|
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight",
|
|
"ip_adapter_proj_model.norm.bias",
|
|
"ip_adapter_proj_model.norm.weight",
|
|
"ip_adapter_proj_model.proj.bias",
|
|
"ip_adapter_proj_model.proj.weight",
|
|
}
|
|
|
|
if expected_keys.issubset(sd.keys()):
|
|
return True
|
|
return False
|
|
|
|
|
|
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
|
|
num_double_blocks = 0
|
|
context_dim = 0
|
|
hidden_dim = 0
|
|
|
|
# Count the number of double blocks.
|
|
double_block_index = 0
|
|
while f"double_blocks.{double_block_index}.processor.ip_adapter_double_stream_k_proj.weight" in state_dict:
|
|
double_block_index += 1
|
|
num_double_blocks = double_block_index
|
|
|
|
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
|
|
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
|
|
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
|
|
clip_extra_context_tokens = state_dict["ip_adapter_proj_model.proj.weight"].shape[0] // context_dim
|
|
|
|
return XlabsIpAdapterParams(
|
|
num_double_blocks=num_double_blocks,
|
|
context_dim=context_dim,
|
|
hidden_dim=hidden_dim,
|
|
clip_embeddings_dim=clip_embeddings_dim,
|
|
clip_extra_context_tokens=clip_extra_context_tokens,
|
|
)
|