mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add model probing for XLabs FLUX IP-Adapter.
This commit is contained in:
@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
@@ -243,8 +244,6 @@ class ModelProbe(object):
|
||||
"cond_stage_model.",
|
||||
"first_stage_model.",
|
||||
"model.diffusion_model.",
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
|
||||
"double_blocks.",
|
||||
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
|
||||
# This prefix is typically used to distinguish between multiple models bundled in a single file.
|
||||
"model.diffusion_model.double_blocks.",
|
||||
@@ -252,6 +251,10 @@ class ModelProbe(object):
|
||||
):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
return ModelType.Main
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
|
||||
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
|
||||
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
@@ -274,7 +277,14 @@ class ModelProbe(object):
|
||||
)
|
||||
):
|
||||
return ModelType.ControlNet
|
||||
elif key.startswith(("image_proj.", "ip_adapter.")):
|
||||
elif key.startswith(
|
||||
(
|
||||
"image_proj.",
|
||||
"ip_adapter.",
|
||||
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
|
||||
"ip_adapter_proj_model.",
|
||||
)
|
||||
):
|
||||
return ModelType.IPAdapter
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
@@ -672,6 +682,10 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
|
||||
if is_state_dict_xlabs_ip_adapter(checkpoint):
|
||||
return BaseModelType.Flux
|
||||
|
||||
for key in checkpoint.keys():
|
||||
if not key.startswith(("image_proj.", "ip_adapter.")):
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user