Add model probing for XLabs FLUX IP-Adapter.

This commit is contained in:
Ryan Dick
2024-10-15 14:58:04 +00:00
parent f939dbdc33
commit 412e79d8e6

View File

@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -243,8 +244,6 @@ class ModelProbe(object):
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
"double_blocks.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
@@ -252,6 +251,10 @@ class ModelProbe(object):
):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
@@ -274,7 +277,14 @@ class ModelProbe(object):
)
):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
elif key.startswith(
(
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
)
):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
@@ -672,6 +682,10 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if is_state_dict_xlabs_ip_adapter(checkpoint):
return BaseModelType.Flux
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue