Compare commits

...

16 Commits

Author SHA1 Message Date
Ryan Dick  125be54b95  Use a black image as the negative IP prompt for parity with X-Labs implementation.  2024-10-18 18:52:12 +00:00
Ryan Dick  b74dcfcef0  Test out IP-Adapter with CFG.  2024-10-16 18:11:48 +00:00
Ryan Dick  494cedf090  Merge branch 'ryan/flux-cfg' into ryan/flux-ip-adapter-cfg  2024-10-16 17:24:42 +00:00
Ryan Dick  423eb2e3b5  A bunch of HACKS to get ViT-L CLIP vision encoder working for FLUX IP-Adapter. Need to revisit how to clean this all up long term.  2024-10-16 01:41:02 +00:00
Ryan Dick  fd3338fe67  Fixes to get XLabsIpAdapterExtension running.  2024-10-16 01:39:48 +00:00
Ryan Dick  58e9c3db1c  Initial draft of integrating FLUX IP-Adapter inference support.  2024-10-15 22:28:59 +00:00
Ryan Dick  c751015bf2  Add IPAdapterDoubleBlocks wrapper to tidy FLUX ip-adapter handling.  2024-10-15 15:31:26 +00:00
Ryan Dick  b91648391b  (minor) Move infer_xlabs_ip_adapter_params_from_state_dict(...) to state_dict_utils.py.  2024-10-15 15:02:03 +00:00
Ryan Dick  e6e1f614a6  Add model loading code for xlabs FLUX IP-Adapter (not tested).  2024-10-15 14:58:37 +00:00
Ryan Dick  a87a83ff6f  Add model probing for XLabs FLUX IP-Adapter.  2024-10-15 14:58:04 +00:00
Ryan Dick  66bb8081df  Add is_state_dict_xlabs_ip_adapter() utility function.  2024-10-15 14:49:00 +00:00
Ryan Dick  05358c3ada  Add logic for loading an Xlabs IP-Adapter from a state dict.  2024-10-15 13:52:07 +00:00
Ryan Dick  3f02984fab  Add initial logic for inferring FLUX IP-Adapter params from a state_dict.  2024-10-11 17:15:10 +00:00
Ryan Dick  7ecc6220d4  Fixup typing/imports for IPDoubleStreamBlockProcessor.  2024-10-11 14:19:37 +00:00
Ryan Dick  b9d5ece22b  Copy IPDoubleStreamBlockProcessor from 47495425db/src/flux/modules/layers.py (L221).  2024-10-11 14:11:11 +00:00
Ryan Dick  c01e2ca07e  Add XLabs IP-Adapter state dict for unit tests.  2024-10-11 13:12:04 +00:00
17 changed files with 793 additions and 26 deletions

View File

@@ -1,15 +1,19 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
@@ -17,6 +21,7 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -26,6 +31,8 @@ from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -106,6 +113,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -228,6 +239,14 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
# Compute the IP-Adapter image prompt clip embeddings.
# We do this before loading other models to minimize peak memory.
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
ip_adapter_fields, context
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -276,6 +295,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare IP-Adapter extensions.
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
neg_image_prompt_clip_embeds=neg_image_prompt_clip_embeds,
ip_adapter_fields=ip_adapter_fields,
context=context,
exit_stack=exit_stack,
dtype=inference_dtype,
)
x = denoise(
model=transformer,
img=x,
@@ -292,6 +321,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
cfg_scale=self.cfg_scale,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -436,6 +467,107 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return controlnet_extensions
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:
return []
elif isinstance(self.ip_adapter, IPAdapterField):
return [self.ip_adapter]
elif isinstance(self.ip_adapter, list):
return self.ip_adapter
else:
raise ValueError(f"Unsupported IP-Adapter type: {type(self.ip_adapter)}")
def _prep_ip_adapter_image_prompt_clip_embeds(
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
clip_image_processor = CLIPImageProcessor()
pos_image_prompt_clip_embeds: list[torch.Tensor] = []
neg_image_prompt_clip_embeds: list[torch.Tensor] = []
for ip_adapter_field in ip_adapter_fields:
# `ip_adapter_field.image` could be a list or a single ImageField. Normalize to a list here.
ipa_image_fields: list[ImageField]
if isinstance(ip_adapter_field.image, ImageField):
ipa_image_fields = [ip_adapter_field.image]
elif isinstance(ip_adapter_field.image, list):
ipa_image_fields = ip_adapter_field.image
else:
raise ValueError(f"Unsupported IP-Adapter image type: {type(ip_adapter_field.image)}")
ipa_images = [context.images.get_pil(image.image_name) for image in ipa_image_fields]
pos_images: list[npt.NDArray[np.uint8]] = []
neg_images: list[npt.NDArray[np.uint8]] = []
for ipa_image in ipa_images:
assert ipa_image.mode == "RGB"
pos_image = np.array(ipa_image)
# We use a black image as the negative image prompt for parity with
# https://github.com/XLabs-AI/x-flux-comfyui/blob/45c834727dd2141aebc505ae4b01f193a8414e38/nodes.py#L592-L593
# An alternative scheme would be to apply zeros_like() after calling the clip_image_processor.
neg_image = np.zeros_like(pos_image)
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
neg_image_prompt_clip_embeds.append(neg_clip_image_embeds)
return pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds
def _prep_ip_adapter_extensions(
self,
ip_adapter_fields: list[IPAdapterField],
pos_image_prompt_clip_embeds: list[torch.Tensor],
neg_image_prompt_clip_embeds: list[torch.Tensor],
context: InvocationContext,
exit_stack: ExitStack,
dtype: torch.dtype,
) -> tuple[list[XLabsIPAdapterExtension], list[XLabsIPAdapterExtension]]:
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
raise ValueError("IP-Adapter masks are not yet supported in Flux.")
ip_adapter_extension = XLabsIPAdapterExtension(
model=ip_adapter_model,
image_prompt_clip_embed=pos_image_prompt_clip_embed,
weight=ip_adapter_field.weight,
begin_step_percent=ip_adapter_field.begin_step_percent,
end_step_percent=ip_adapter_field.end_step_percent,
)
ip_adapter_extension.run_image_proj(dtype=dtype)
pos_ip_adapter_extensions.append(ip_adapter_extension)
ip_adapter_extension = XLabsIPAdapterExtension(
model=ip_adapter_model,
image_prompt_clip_embed=neg_image_prompt_clip_embed,
weight=ip_adapter_field.weight,
begin_step_percent=ip_adapter_field.begin_step_percent,
end_step_percent=ip_adapter_field.end_step_percent,
)
ip_adapter_extension.run_image_proj(dtype=dtype)
neg_ip_adapter_extensions.append(ip_adapter_extension)
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
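
As an aside on the embedding step in _prep_ip_adapter_image_prompt_clip_embeds above, here is a minimal standalone sketch of the positive/negative image prompt encoding. The checkpoint id and the image path are placeholders (the invocation resolves the encoder through the model manager instead); the key point is that the negative prompt is an all-black image of the same size as the positive one, per the X-Labs parity note in the code.

import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Placeholder encoder id; the diff installs/loads a clip-vit-large-patch14 model via the model manager.
processor = CLIPImageProcessor()
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

pil_image = Image.open("prompt_image.png").convert("RGB")  # placeholder path
pos_image = np.array(pil_image)
neg_image = np.zeros_like(pos_image)  # black image as the negative image prompt

with torch.no_grad():
    pos_pixels = processor(images=[pos_image], return_tensors="pt").pixel_values
    neg_pixels = processor(images=[neg_image], return_tensors="pt").pixel_values
    pos_embeds = encoder(pos_pixels).image_embeds  # (1, 768) for ViT-L/14
    neg_embeds = encoder(neg_pixels).image_embeds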

View File

@@ -9,6 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Outpu
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -55,10 +56,14 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
CLIP_VISION_MODEL_MAP = {
"ViT-L": ("InvokeAI/clip-vit-large-patch14", "clip-vit-large-patch14-full"),
"ViT-H": ("InvokeAI/ip_adapter_sd_image_encoder", "ip_adapter_sd_image_encoder"),
"ViT-G": ("InvokeAI/ip_adapter_sdxl_image_encoder", "ip_adapter_sdxl_image_encoder"),
}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.5.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@@ -70,7 +75,7 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="ViT-H",
ui_order=2,
@@ -111,9 +116,9 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model_id, image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
image_encoder_model = self._get_image_encoder(context, image_encoder_model_id, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
@@ -147,7 +152,9 @@ class IPAdapterInvocation(BaseInvocation):
),
)
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
def _get_image_encoder(
self, context: InvocationContext, image_encoder_model_id: str, image_encoder_model_name: str
) -> AnyModelConfig:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
@@ -159,7 +166,11 @@ class IPAdapterInvocation(BaseInvocation):
)
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
# Note: We hard-code the type to CLIPVision here because if the model contains both a CLIPVision and a
# CLIPText model, the probe may treat it as a CLIPText model.
job = installer.heuristic_import(
image_encoder_model_id, ModelRecordChanges(name=image_encoder_model_name, type=ModelType.CLIPVision)
)
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision

View File

@@ -40,7 +40,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")

View File

@@ -0,0 +1,83 @@
import einops
import torch
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
class CustomDoubleStreamBlockProcessor:
"""A class containing a custom implementation of DoubleStreamBlock.forward() with additional features
(IP-Adapter, etc.).
"""
@staticmethod
def _double_stream_block_forward(
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
values.
"""
img_mod1, img_mod2 = block.img_mod(vec)
txt_mod1, txt_mod2 = block.txt_mod(vec)
# prepare image for attention
img_modulated = block.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = block.img_attn.qkv(img_modulated)
img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
img_q, img_k = block.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = block.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = block.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
txt_q, txt_k = block.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img = img + img_mod1.gate * block.img_attn.proj(img_attn)
img = img + img_mod2.gate * block.img_mlp((1 + img_mod2.scale) * block.img_norm2(img) + img_mod2.shift)
# calculate the txt blocks
txt = txt + txt_mod1.gate * block.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * block.txt_mlp((1 + txt_mod2.scale) * block.txt_norm2(txt) + txt_mod2.shift)
return img, txt, img_q
@staticmethod
def custom_double_block_forward(
timestep_index: int,
total_num_timesteps: int,
block_index: int,
block: DoubleStreamBlock,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
) -> tuple[torch.Tensor, torch.Tensor]:
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
- IP-Adapter support
"""
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
# Apply IP-Adapter conditioning.
for ip_adapter_extension in ip_adapter_extensions:
img = ip_adapter_extension.run_ip_adapter(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
block_index=block_index,
block=block,
img_q=img_q,
img=img,
)
return img, txt

View File

@@ -7,6 +7,7 @@ from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFl
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -31,6 +32,8 @@ def denoise(
cfg_scale: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -49,12 +52,14 @@ def denoise(
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
timestep_index = step - 1
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
timestep_index=timestep_index,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
@@ -80,25 +85,32 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=timestep_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
)
# TODO(ryand): Add option to apply controlnet to negative conditioning as well.
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
)
pred = neg_pred + cfg_scale * (pred - neg_pred)
if step > 1:
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=timestep_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
)
pred = neg_pred + cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
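
For reference, a minimal sketch of the per-step update in the loop above, with dummy tensors standing in for the transformer predictions (shapes and values are illustrative, not taken from the diff). On the first step only the positive prediction is used; afterwards the negative prediction is folded in with the usual CFG blend before the Euler step.

import torch

cfg_scale = 3.5
img = torch.randn(1, 4096, 64)        # packed FLUX latents (dummy shape)
pred = torch.randn_like(img)          # prediction under positive conditioning
neg_pred = torch.randn_like(img)      # prediction under negative conditioning

step, t_curr, t_prev = 2, 0.75, 0.70
if step > 1:
    # Classifier-free guidance: extrapolate away from the negative prediction.
    pred = neg_pred + cfg_scale * (pred - neg_pred)

preview_img = img - t_curr * pred     # rough denoised estimate (used for previews)
img = img + (t_prev - t_curr) * pred  # Euler step from t_curr to t_prev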

View File

@@ -0,0 +1,89 @@
import math
from typing import List, Union
import einops
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
class XLabsIPAdapterExtension:
def __init__(
self,
model: XlabsIpAdapterFlux,
image_prompt_clip_embed: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._model = model
self._image_prompt_clip_embed = image_prompt_clip_embed
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
self._image_proj: torch.Tensor | None = None
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return 0.0
if isinstance(self._weight, list):
return self._weight[timestep_index]
return self._weight
@staticmethod
def run_clip_image_encoder(
pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection
) -> torch.Tensor:
clip_image_processor = CLIPImageProcessor()
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
clip_image_embeds = image_encoder(clip_image).image_embeds
return clip_image_embeds
def run_image_proj(self, dtype: torch.dtype):
image_prompt_clip_embed = self._image_prompt_clip_embed.to(dtype=dtype)
self._image_proj = self._model.image_proj(image_prompt_clip_embed)
def run_ip_adapter(
self,
timestep_index: int,
total_num_timesteps: int,
block_index: int,
block: DoubleStreamBlock,
img_q: torch.Tensor,
img: torch.Tensor,
) -> torch.Tensor:
"""The logic in this function is based on:
https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/modules/layers.py#L245-L301
"""
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return img
ip_adapter_block = self._model.ip_adapter_double_blocks.double_blocks[block_index]
ip_key = ip_adapter_block.ip_adapter_double_stream_k_proj(self._image_proj)
ip_value = ip_adapter_block.ip_adapter_double_stream_v_proj(self._image_proj)
# Reshape projections for multi-head attention.
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=block.num_heads)
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=block.num_heads)
# Compute attention between IP projections and the latent query.
ip_attn = torch.nn.functional.scaled_dot_product_attention(
img_q, ip_key, ip_value, dropout_p=0.0, is_causal=False
)
ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)", H=block.num_heads)
img = img + weight * ip_attn
return img
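
A small standalone sketch of the step gating implemented by _get_weight() above; the helper below mirrors that logic outside the class so it can be run on its own (the 20-step schedule and 0.8 weight are arbitrary).

import math
from typing import List, Union

def ip_adapter_weight(
    weight: Union[float, List[float]],
    begin_step_percent: float,
    end_step_percent: float,
    timestep_index: int,
    total_num_timesteps: int,
) -> float:
    first_step = math.floor(begin_step_percent * total_num_timesteps)
    last_step = math.ceil(end_step_percent * total_num_timesteps)
    if timestep_index < first_step or timestep_index > last_step:
        return 0.0
    # A per-step weight list (indexed by timestep) is also supported.
    return weight[timestep_index] if isinstance(weight, list) else weight

# Weight 0.8, applied only to the middle half of a 20-step schedule.
print([ip_adapter_weight(0.8, 0.25, 0.75, i, 20) for i in range(20)])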

View File

@@ -0,0 +1,93 @@
# This file is based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/modules/layers.py#L221
import einops
import torch
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
class IPDoubleStreamBlockProcessor(torch.nn.Module):
"""Attention processor for handling IP-adapter with double stream block."""
def __init__(self, context_dim: int, hidden_dim: int):
super().__init__()
# Ensure context_dim matches the dimension of image_proj
self.context_dim = context_dim
self.hidden_dim = hidden_dim
# Initialize projections for IP-adapter
self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
torch.nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
torch.nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
torch.nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
torch.nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
def __call__(
self,
attn: DoubleStreamBlock,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
image_proj: torch.Tensor,
ip_scale: float = 1.0,
):
# Prepare image for attention
img_mod1, img_mod2 = attn.img_mod(vec)
txt_mod1, txt_mod2 = attn.txt_mod(vec)
img_modulated = attn.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = attn.img_attn.qkv(img_modulated)
img_q, img_k, img_v = einops.rearrange(
img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
)
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
txt_modulated = attn.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = attn.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = einops.rearrange(
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
)
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn1 = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
# print(f"txt_attn shape: {txt_attn.size()}")
# print(f"img_attn shape: {img_attn.size()}")
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
# IP-adapter processing
ip_query = img_q # latent sample query
ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
# Reshape projections for multi-head attention
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
# Compute attention between IP projections and the latent query
ip_attention = torch.nn.functional.scaled_dot_product_attention(
ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False
)
ip_attention = einops.rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
img = img + ip_scale * ip_attention
return img, txt
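
A quick shape check of the IP-Adapter key/value projection and attention pattern used by run_ip_adapter() and IPDoubleStreamBlockProcessor above, on dummy tensors. The head count, head dim, latent token length, and the four projected image-prompt tokens are assumptions chosen to be consistent with the XLabs state-dict shapes in the test fixture later in this diff; treat them as illustrative.

import einops
import torch

B, H, D = 1, 24, 128                   # batch, attention heads, head dim
L_img, L_ip = 4096, 4                  # latent image tokens, projected image-prompt tokens
context_dim, hidden_dim = 4096, H * D  # 4096 -> 3072 projections

image_proj = torch.randn(B, L_ip, context_dim)
k_proj = torch.nn.Linear(context_dim, hidden_dim)
v_proj = torch.nn.Linear(context_dim, hidden_dim)

img_q = torch.randn(B, H, L_img, D)    # query taken from the latent image stream
ip_key = einops.rearrange(k_proj(image_proj), "B L (H D) -> B H L D", H=H)
ip_value = einops.rearrange(v_proj(image_proj), "B L (H D) -> B H L D", H=H)

ip_attn = torch.nn.functional.scaled_dot_product_attention(img_q, ip_key, ip_value)
ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)")
assert ip_attn.shape == (B, L_img, hidden_dim)  # this residual is added to img, scaled by the IP weight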

View File

@@ -0,0 +1,50 @@
from typing import Any, Dict
import torch
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an XLabs FLUX IP-Adapter model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an XLabs IP-Adapter model.
expected_keys = {
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias",
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight",
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias",
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight",
"ip_adapter_proj_model.norm.bias",
"ip_adapter_proj_model.norm.weight",
"ip_adapter_proj_model.proj.bias",
"ip_adapter_proj_model.proj.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
num_double_blocks = 0
context_dim = 0
hidden_dim = 0
# Count the number of double blocks.
double_block_index = 0
while f"double_blocks.{double_block_index}.processor.ip_adapter_double_stream_k_proj.weight" in state_dict:
double_block_index += 1
num_double_blocks = double_block_index
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
return XlabsIpAdapterParams(
num_double_blocks=num_double_blocks,
context_dim=context_dim,
hidden_dim=hidden_dim,
clip_embeddings_dim=clip_embeddings_dim,
)
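
A hedged usage sketch of the two helpers above, assuming a locally downloaded XLabs FLUX IP-Adapter checkpoint (the file path is a placeholder).

from safetensors.torch import load_file

from invokeai.backend.flux.ip_adapter.state_dict_utils import (
    infer_xlabs_ip_adapter_params_from_state_dict,
    is_state_dict_xlabs_ip_adapter,
)

sd = load_file("flux-ip-adapter.safetensors")  # placeholder path
if is_state_dict_xlabs_ip_adapter(sd):
    params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
    print(params.num_double_blocks, params.context_dim, params.hidden_dim, params.clip_embeddings_dim)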

View File

@@ -0,0 +1,67 @@
from dataclasses import dataclass
import torch
from invokeai.backend.ip_adapter.ip_adapter import ImageProjModel
class IPDoubleStreamBlock(torch.nn.Module):
def __init__(self, context_dim: int, hidden_dim: int):
super().__init__()
self.context_dim = context_dim
self.hidden_dim = hidden_dim
self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
class IPAdapterDoubleBlocks(torch.nn.Module):
def __init__(self, num_double_blocks: int, context_dim: int, hidden_dim: int):
super().__init__()
self.double_blocks = torch.nn.ModuleList(
[IPDoubleStreamBlock(context_dim, hidden_dim) for _ in range(num_double_blocks)]
)
@dataclass
class XlabsIpAdapterParams:
num_double_blocks: int
context_dim: int
hidden_dim: int
clip_embeddings_dim: int
class XlabsIpAdapterFlux(torch.nn.Module):
def __init__(self, params: XlabsIpAdapterParams):
super().__init__()
self.image_proj = ImageProjModel(
cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
)
self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim
)
def load_xlabs_state_dict(self, state_dict: dict[str, torch.Tensor], assign: bool = False):
"""We need this custom function to load state dicts rather than using .load_state_dict(...) because the model
structure does not match the state_dict structure.
"""
# Split the state_dict into the image projection model and the double blocks.
image_proj_sd: dict[str, torch.Tensor] = {}
double_blocks_sd: dict[str, torch.Tensor] = {}
for k, v in state_dict.items():
if k.startswith("ip_adapter_proj_model."):
image_proj_sd[k] = v
elif k.startswith("double_blocks."):
double_blocks_sd[k] = v
else:
raise ValueError(f"Unexpected key: {k}")
# Initialize the image projection model.
image_proj_sd = {k.replace("ip_adapter_proj_model.", ""): v for k, v in image_proj_sd.items()}
self.image_proj.load_state_dict(image_proj_sd, assign=assign)
# Initialize the double blocks.
double_blocks_sd = {k.replace("processor.", ""): v for k, v in double_blocks_sd.items()}
self.ip_adapter_double_blocks.load_state_dict(double_blocks_sd, assign=assign)

View File

@@ -5,6 +5,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
@@ -88,8 +90,11 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None,
timestep_index: int,
total_num_timesteps: int,
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -111,7 +116,19 @@ class Flux(nn.Module):
if controlnet_double_block_residuals is not None:
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
assert isinstance(block, DoubleStreamBlock)
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
block_index=block_index,
block=block,
img=img,
txt=txt,
vec=vec,
pe=pe,
ip_adapter_extensions=ip_adapter_extensions,
)
if controlnet_double_block_residuals is not None:
img += controlnet_double_block_residuals[block_index]

View File

@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Optional
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
DiffusersConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
class ClipVisionLoader(ModelLoader):
"""Class to load CLIPVision models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, DiffusersConfigBase):
raise ValueError("Only DiffusersConfigBase models are currently supported here.")
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {config.type}")
model_path = Path(config.path)
model = CLIPVisionModelWithProjection.from_pretrained(
model_path, torch_dtype=self._torch_dtype, local_files_only=True
)
assert isinstance(model, CLIPVisionModelWithProjection)
return model

View File

@@ -19,6 +19,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.ip_adapter.state_dict_utils import infer_xlabs_ip_adapter_params_from_state_dict
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
XlabsIpAdapterFlux,
)
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
@@ -35,6 +39,7 @@ from invokeai.backend.model_manager.config import (
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
IPAdapterCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
@@ -352,3 +357,26 @@ class FluxControlnetModel(ModelLoader):
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
class FluxIpAdapterModel(ModelLoader):
"""Class to load FLUX IP-Adapter models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, IPAdapterCheckpointConfig):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
with accelerate.init_empty_weights():
model = XlabsIpAdapterFlux(params=params)
model.load_xlabs_state_dict(sd, assign=True)
return model
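
The loader above relies on accelerate's init_empty_weights together with assign-style state dict loading, which avoids materializing random weights before the checkpoint tensors are loaded. A toy sketch of that pattern (not InvokeAI code), assuming torch >= 2.1 for load_state_dict(..., assign=True):

import accelerate
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 4)

with accelerate.init_empty_weights():
    model = Toy()  # parameters live on the meta device; no memory is allocated

sd = {"proj.weight": torch.randn(4, 8), "proj.bias": torch.randn(4)}
model.load_state_dict(sd, assign=True)  # assign replaces the meta tensors with the loaded ones
print(model.proj.weight.device)         # cpu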

View File

@@ -22,7 +22,6 @@ from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""

View File

@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -243,8 +244,6 @@ class ModelProbe(object):
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
"double_blocks.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
@@ -252,6 +251,10 @@ class ModelProbe(object):
):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
@@ -274,7 +277,14 @@ class ModelProbe(object):
)
):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
elif key.startswith(
(
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
)
):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
@@ -672,6 +682,10 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if is_state_dict_xlabs_ip_adapter(checkpoint):
return BaseModelType.Flux
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue

View File

@@ -0,0 +1,46 @@
import accelerate
import torch
from invokeai.backend.flux.ip_adapter.state_dict_utils import (
infer_xlabs_ip_adapter_params_from_state_dict,
is_state_dict_xlabs_ip_adapter,
)
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
XlabsIpAdapterFlux,
)
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes
def test_is_state_dict_xlabs_ip_adapter():
# Construct a dummy state_dict.
sd = {k: None for k in xlabs_sd_shapes}
assert is_state_dict_xlabs_ip_adapter(sd)
def test_infer_xlabs_ip_adapter_params_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
assert params.num_double_blocks == 19
assert params.context_dim == 4096
assert params.hidden_dim == 3072
assert params.clip_embeddings_dim == 768
def test_initialize_xlabs_ip_adapter_flux_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
# Initialize the XLabs IP-Adapter from the state_dict.
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
with accelerate.init_empty_weights():
model = XlabsIpAdapterFlux(params=params)
# Smoke test state_dict loading.
model.load_xlabs_state_dict(sd)

View File

@@ -0,0 +1,85 @@
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter/blob/ad16be50d78a07ea83d8c4bde44ff9753235182e/flux-ip-adapter.safetensors
xlabs_sd_shapes = {
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.1.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.1.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.1.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.1.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.10.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.10.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.10.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.10.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.11.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.11.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.11.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.11.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.12.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.12.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.12.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.12.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.13.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.13.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.13.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.13.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.14.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.14.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.14.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.14.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.15.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.15.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.15.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.15.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.16.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.16.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.16.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.16.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.17.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.17.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.17.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.17.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.18.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.18.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.18.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.18.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.2.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.2.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.2.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.2.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.3.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.3.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.3.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.3.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.4.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.4.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.4.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.4.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.5.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.5.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.5.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.5.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.6.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.6.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.6.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.6.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.7.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.7.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.7.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.7.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.8.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.8.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.8.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.8.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"double_blocks.9.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.9.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.9.processor.ip_adapter_double_stream_v_proj.bias": [3072],
"double_blocks.9.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
"ip_adapter_proj_model.norm.bias": [4096],
"ip_adapter_proj_model.norm.weight": [4096],
"ip_adapter_proj_model.proj.bias": [16384],
"ip_adapter_proj_model.proj.weight": [16384, 768],
}