From 31ffd734233c6a917b33dbdc95adb736a854bbcb Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 15 Oct 2024 22:28:59 +0000
Subject: [PATCH] Initial draft of integrating FLUX IP-Adapter inference support.

---
 invokeai/app/invocations/flux_denoise.py     | 92 +++++++++++++++++++
 .../backend/flux/custom_block_processor.py   | 83 +++++++++++++++++
 invokeai/backend/flux/denoise.py             |  5 +
 .../extensions/xlabs_ip_adapter_extension.py | 89 ++++++++++++++++++
 invokeai/backend/flux/model.py               | 19 +++-
 5 files changed, 287 insertions(+), 1 deletion(-)
 create mode 100644 invokeai/backend/flux/custom_block_processor.py
 create mode 100644 invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py

diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 8120ac400f..1b7dea7b60 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -4,12 +4,14 @@ from typing import Callable, Iterator, Optional, Tuple
 import torch
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
+from transformers import CLIPVisionModelWithProjection
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
     FieldDescriptions,
     FluxConditioningField,
+    ImageField,
     Input,
     InputField,
     LatentsField,
@@ -17,6 +19,7 @@ from invokeai.app.invocations.fields import (
     WithMetadata,
 )
 from invokeai.app.invocations.flux_controlnet import FluxControlNetField
+from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -26,6 +29,8 @@ from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
+from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
+from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.sampling_utils import (
     clip_timestep_schedule_fractional,
@@ -118,6 +123,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
 
+    ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
+        description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
+    )
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = self._run_diffusion(context)
@@ -245,6 +254,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             noise=noise,
         )
 
+        # Compute the IP-Adapter image prompt clip embeddings.
+        # We do this before loading other models to minimize peak memory.
+        # TODO(ryand): We should really do this in a separate invocation to benefit from caching.
+        ip_adapter_fields = self._normalize_ip_adapter_fields()
+        image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(ip_adapter_fields, context)
+
         cfg_scale = self.prep_cfg_scale(
             cfg_scale=self.cfg_scale,
             timesteps=timesteps,
@@ -300,6 +315,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             else:
                 raise ValueError(f"Unsupported model format: {config.format}")
 
+        # Prepare IP-Adapter extensions.
+        ip_adapter_extensions = self._prep_ip_adapter_extensions(
+            image_prompt_clip_embeds=image_prompt_clip_embeds,
+            ip_adapter_fields=ip_adapter_fields,
+            context=context,
+            exit_stack=exit_stack,
+            dtype=inference_dtype,
+        )
+
         x = denoise(
             model=transformer,
             img=x,
@@ -316,6 +340,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             cfg_scale=cfg_scale,
             inpaint_extension=inpaint_extension,
             controlnet_extensions=controlnet_extensions,
+            ip_adapter_extensions=ip_adapter_extensions,
         )
 
         x = unpack(x.float(), self.height, self.width)
@@ -509,6 +534,73 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         return controlnet_extensions
 
+    def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
+        if self.ip_adapter is None:
+            return []
+        elif isinstance(self.ip_adapter, IPAdapterField):
+            return [self.ip_adapter]
+        elif isinstance(self.ip_adapter, list):
+            return self.ip_adapter
+        else:
+            raise ValueError(f"Unsupported IP-Adapter type: {type(self.ip_adapter)}")
+
+    def _prep_ip_adapter_image_prompt_clip_embeds(
+        self,
+        ip_adapter_fields: list[IPAdapterField],
+        context: InvocationContext,
+    ) -> list[torch.Tensor]:
+        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
+        image_prompt_clip_embeds: list[torch.Tensor] = []
+        for ip_adapter_field in ip_adapter_fields:
+            # `ip_adapter_field.image` could be a list or a single ImageField. Normalize to a list here.
+            ipa_image_fields: list[ImageField]
+            if isinstance(ip_adapter_field.image, ImageField):
+                ipa_image_fields = [ip_adapter_field.image]
+            elif isinstance(ip_adapter_field.image, list):
+                ipa_image_fields = ip_adapter_field.image
+            else:
+                raise ValueError(f"Unsupported IP-Adapter image type: {type(ip_adapter_field.image)}")
+
+            ipa_images = [context.images.get_pil(image.image_name) for image in ipa_image_fields]
+
+            with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
+                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
+                image_prompt_clip_embeds.append(
+                    XLabsIPAdapterExtension.run_clip_image_encoder(
+                        pil_image=ipa_images,
+                        image_encoder=image_encoder_model,
+                    )
+                )
+        return image_prompt_clip_embeds
+
+    def _prep_ip_adapter_extensions(
+        self,
+        ip_adapter_fields: list[IPAdapterField],
+        image_prompt_clip_embeds: list[torch.Tensor],
+        context: InvocationContext,
+        exit_stack: ExitStack,
+        dtype: torch.dtype,
+    ) -> list[XLabsIPAdapterExtension]:
+        ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
+        for ip_adapter_field, image_prompt_clip_embed in zip(ip_adapter_fields, image_prompt_clip_embeds, strict=True):
+            ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
+            assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
+            ip_adapter_model = ip_adapter_model.to(dtype=dtype)
+            if ip_adapter_field.mask is not None:
+                raise ValueError("IP-Adapter masks are not yet supported in Flux.")
+            ip_adapter_extension = XLabsIPAdapterExtension(
+                model=ip_adapter_model,
+                image_prompt_clip_embed=image_prompt_clip_embed,
+                weight=ip_adapter_field.weight,
+                begin_step_percent=ip_adapter_field.begin_step_percent,
+                end_step_percent=ip_adapter_field.end_step_percent,
+            )
+
+            ip_adapter_extension.run_image_proj(dtype=dtype)
+            ip_adapter_extensions.append(ip_adapter_extension)
+
+        return ip_adapter_extensions
+
     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in self.transformer.loras:
             lora_info = context.models.load(lora.lora)
diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py
new file mode 100644
index 0000000000..e0c7779e93
--- /dev/null
+++ b/invokeai/backend/flux/custom_block_processor.py
@@ -0,0 +1,83 @@
+import einops
+import torch
+
+from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
+from invokeai.backend.flux.math import attention
+from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+
+
+class CustomDoubleStreamBlockProcessor:
+    """A class containing a custom implementation of DoubleStreamBlock.forward() with additional features
+    (IP-Adapter, etc.).
+    """
+
+    @staticmethod
+    def _double_stream_block_forward(
+        block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
+        values.
+        """
+        img_mod1, img_mod2 = block.img_mod(vec)
+        txt_mod1, txt_mod2 = block.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = block.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = block.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
+        img_q, img_k = block.img_attn.norm(img_q, img_k, img_v)
+
+        # prepare txt for attention
+        txt_modulated = block.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = block.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
+        txt_q, txt_k = block.txt_attn.norm(txt_q, txt_k, txt_v)
+
+        # run actual attention
+        q = torch.cat((txt_q, img_q), dim=2)
+        k = torch.cat((txt_k, img_k), dim=2)
+        v = torch.cat((txt_v, img_v), dim=2)
+
+        attn = attention(q, k, v, pe=pe)
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * block.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * block.img_mlp((1 + img_mod2.scale) * block.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+        txt = txt + txt_mod1.gate * block.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * block.txt_mlp((1 + txt_mod2.scale) * block.txt_norm2(txt) + txt_mod2.shift)
+        return img, txt, img_q
+
+    @staticmethod
+    def custom_double_block_forward(
+        timestep_index: int,
+        total_num_timesteps: int,
+        block_index: int,
+        block: DoubleStreamBlock,
+        img: torch.Tensor,
+        txt: torch.Tensor,
+        vec: torch.Tensor,
+        pe: torch.Tensor,
+        ip_adapter_extensions: list[XLabsIPAdapterExtension],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """A custom implementation of DoubleStreamBlock.forward() with additional features:
+        - IP-Adapter support
+        """
+        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
+
+        # Apply IP-Adapter conditioning.
+        for ip_adapter_extension in ip_adapter_extensions:
+            img = ip_adapter_extension.run_ip_adapter(
+                timestep_index=timestep_index,
+                total_num_timesteps=total_num_timesteps,
+                block_index=block_index,
+                block=block,
+                img_q=img_q,
+                img=img,
+            )
+
+        return img, txt
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 7ce375f4a2..025586f4e0 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -8,6 +8,7 @@ from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFl
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
+from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 
@@ -32,6 +33,7 @@ def denoise(
     cfg_scale: list[float],
     inpaint_extension: InpaintExtension | None,
     controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
+    ip_adapter_extensions: list[XLabsIPAdapterExtension],
 ):
     # step 0 is the initial state
     total_steps = len(timesteps) - 1
@@ -80,8 +82,11 @@ def denoise(
             y=vec,
             timesteps=t_vec,
             guidance=guidance_vec,
+            timestep_index=step_index,
+            total_num_timesteps=total_steps,
             controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
             controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+            ip_adapter_extensions=ip_adapter_extensions,
         )
 
         step_cfg_scale = cfg_scale[step_index]
diff --git a/invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py b/invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py
new file mode 100644
index 0000000000..13ebb1451f
--- /dev/null
+++ b/invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py
@@ -0,0 +1,89 @@
+import math
+from typing import List, Union
+
+import einops
+import torch
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
+from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+
+
+class XLabsIPAdapterExtension:
+    def __init__(
+        self,
+        model: XlabsIpAdapterFlux,
+        image_prompt_clip_embed: torch.Tensor,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+    ):
+        self._model = model
+        self._image_prompt_clip_embed = image_prompt_clip_embed
+        self._weight = weight
+        self._begin_step_percent = begin_step_percent
+        self._end_step_percent = end_step_percent
+
+        self._image_proj: torch.Tensor | None = None
+
+    def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
+        first_step = math.floor(self._begin_step_percent * total_num_timesteps)
+        last_step = math.ceil(self._end_step_percent * total_num_timesteps)
+
+        if timestep_index < first_step or timestep_index > last_step:
+            return 0.0
+
+        if isinstance(self._weight, list):
+            return self._weight[timestep_index]
+
+        return self._weight
+
+    @staticmethod
+    def run_clip_image_encoder(
+        pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection
+    ) -> torch.Tensor:
+        clip_image_processor = CLIPImageProcessor()
+        clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+        clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
+        clip_image_embeds = image_encoder(clip_image).image_embeds
+        return clip_image_embeds
+
+    def run_image_proj(self, dtype: torch.dtype):
+        image_prompt_clip_embed = self._image_prompt_clip_embed.to(dtype=dtype)
+        self._image_proj = self._model.image_proj(image_prompt_clip_embed)
+
+    def run_ip_adapter(
+        self,
+        timestep_index: int,
+        total_num_timesteps: int,
+        block_index: int,
+        block: DoubleStreamBlock,
+        img_q: torch.Tensor,
+        img: torch.Tensor,
+    ) -> torch.Tensor:
+        """The logic in this function is based on:
+        https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/modules/layers.py#L245-L301
+        """
+        weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
+        if weight < 1e-6:
+            return img
+
+        ip_adapter_block = self._model.ip_adapter_double_blocks.double_blocks[block_index]
+
+        ip_key = ip_adapter_block.ip_adapter_double_stream_k_proj(self._image_proj)
+        ip_value = ip_adapter_block.ip_adapter_double_stream_v_proj(self._image_proj)
+
+        # Reshape projections for multi-head attention.
+        ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=block.num_heads, D=block.head_dim)
+        ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=block.num_heads, D=block.head_dim)
+
+        # Compute attention between IP projections and the latent query.
+        ip_attn = torch.nn.functional.scaled_dot_product_attention(
+            img_q, ip_key, ip_value, dropout_p=0.0, is_causal=False
+        )
+        ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)", H=block.num_heads, D=block.head_dim)
+
+        img = img + weight * ip_attn
+
+        return img
diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py
index 3ec4c3922a..0dadacd8fe 100644
--- a/invokeai/backend/flux/model.py
+++ b/invokeai/backend/flux/model.py
@@ -5,6 +5,8 @@ from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
 
+from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.modules.layers import (
     DoubleStreamBlock,
     EmbedND,
@@ -88,8 +90,11 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor | None,
+        timestep_index: int,
+        total_num_timesteps: int,
         controlnet_double_block_residuals: list[Tensor] | None,
         controlnet_single_block_residuals: list[Tensor] | None,
+        ip_adapter_extensions: list[XLabsIPAdapterExtension],
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -111,7 +116,19 @@ class Flux(nn.Module):
         if controlnet_double_block_residuals is not None:
             assert len(controlnet_double_block_residuals) == len(self.double_blocks)
         for block_index, block in enumerate(self.double_blocks):
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+            assert isinstance(block, DoubleStreamBlock)
+
+            img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
+                timestep_index=timestep_index,
+                total_num_timesteps=total_num_timesteps,
+                block_index=block_index,
+                block=block,
+                img=img,
+                txt=txt,
+                vec=vec,
+                pe=pe,
+                ip_adapter_extensions=ip_adapter_extensions,
+            )
 
             if controlnet_double_block_residuals is not None:
                 img += controlnet_double_block_residuals[block_index]
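
Note (appended after the diff, not part of the patch): the new XLabsIPAdapterExtension gates its contribution per denoising step through _get_weight(), driven by begin_step_percent / end_step_percent. The standalone sketch below mirrors that gating logic so the step-range behavior can be checked in isolation; the function name and the example values are illustrative only.

# Standalone sketch of the step gating used by XLabsIPAdapterExtension._get_weight().
# Hypothetical helper for illustration; not part of the patch.
import math
from typing import List, Union


def ip_adapter_step_weight(
    weight: Union[float, List[float]],
    timestep_index: int,
    total_num_timesteps: int,
    begin_step_percent: float,
    end_step_percent: float,
) -> float:
    # Outside [first_step, last_step] the IP-Adapter contributes nothing.
    first_step = math.floor(begin_step_percent * total_num_timesteps)
    last_step = math.ceil(end_step_percent * total_num_timesteps)
    if timestep_index < first_step or timestep_index > last_step:
        return 0.0
    # A list weight is indexed per step; a scalar weight applies to every active step.
    if isinstance(weight, list):
        return weight[timestep_index]
    return weight


if __name__ == "__main__":
    # With 10 steps and a 0.2-0.8 range, steps 2..8 receive the configured weight; the rest get 0.0.
    for step in range(10):
        print(step, ip_adapter_step_weight(0.6, step, total_num_timesteps=10, begin_step_percent=0.2, end_step_percent=0.8))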