From b2d5b53b5ff31385e44dfd2d7e84997f4834eef5 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 8 Sep 2023 11:47:36 -0400
Subject: [PATCH] Pass IP-Adapter conditioning via cross_attention_kwargs
 instead of concatenating to the text embedding. This avoids interference
 with other features that manipulate the text embedding (e.g. long prompts).

---
 .../backend/ip_adapter/attention_processor.py | 120 +++++++++++-------
 invokeai/backend/ip_adapter/ip_adapter.py     |   1 -
 .../stable_diffusion/diffusers_pipeline.py    |  25 +---
 .../diffusion/conditioning_data.py            |  14 ++
 .../diffusion/shared_invokeai_diffusion.py    |  43 ++++++-
 5 files changed, 135 insertions(+), 68 deletions(-)

diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 99d9edc5dd..b2876620cd 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -19,12 +19,42 @@ class AttnProcessor(DiffusersAttnProcessor, nn.Module):
         DiffusersAttnProcessor.__init__(self)
         nn.Module.__init__(self)
 
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        ip_adapter_image_prompt_embeds=None,
+    ):
+        """Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
+        ip_adapter_image_prompt_embeds parameter.
+        """
+        return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
+
 
 class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
     def __init__(self):
         DiffusersAttnProcessor2_0.__init__(self)
         nn.Module.__init__(self)
 
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        ip_adapter_image_prompt_embeds=None,
+    ):
+        """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
+        ip_adapter_image_prompt_embeds parameter.
+        """
+        return DiffusersAttnProcessor2_0.__call__(
+            self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
+        )
+
 
 class IPAttnProcessor(nn.Module):
     r"""
@@ -32,21 +62,17 @@ class IPAttnProcessor(nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
-        image_embedding_len (`int`):
-            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
-            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.image_embedding_len = image_embedding_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -59,7 +85,18 @@ class IPAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        ip_adapter_image_prompt_embeds=None,
     ):
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            # The batch dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
+            # The channel dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
+            ip_hidden_states = ip_adapter_image_prompt_embeds
+
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -86,12 +123,6 @@ class IPAttnProcessor(nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # Split text encoder hidden states and image encoder hidden state.
-        encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : -self.image_embedding_len, :],
-            encoder_hidden_states[:, -self.image_embedding_len :, :],
-        )
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -103,18 +134,18 @@ class IPAttnProcessor(nn.Module):
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
+        if ip_hidden_states is not None:
+            ip_key = self.to_k_ip(ip_hidden_states)
+            ip_value = self.to_v_ip(ip_hidden_states)
 
-        ip_key = attn.head_to_batch_dim(ip_key)
-        ip_value = attn.head_to_batch_dim(ip_value)
+            ip_key = attn.head_to_batch_dim(ip_key)
+            ip_value = attn.head_to_batch_dim(ip_value)
 
-        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
 
-        hidden_states = hidden_states + self.scale * ip_hidden_states
+            hidden_states = hidden_states + self.scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -138,16 +169,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
-        image_embedding_len (`int`):
-            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
-            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -155,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -168,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        ip_adapter_image_prompt_embeds=None,
     ):
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            # The batch dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
+            # The channel dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
+            ip_hidden_states = ip_adapter_image_prompt_embeds
+
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -200,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # Split text encoder hidden states and image encoder hidden state.
-        encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : -self.image_embedding_len, :],
-            encoder_hidden_states[:, -self.image_embedding_len :, :],
-        )
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -226,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
+        if ip_hidden_states is not None:
+            ip_key = self.to_k_ip(ip_hidden_states)
+            ip_value = self.to_v_ip(ip_hidden_states)
 
-        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        ip_hidden_states = F.scaled_dot_product_attention(
-            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            ip_hidden_states = F.scaled_dot_product_attention(
+                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
 
-        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        ip_hidden_states = ip_hidden_states.to(query.dtype)
+            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-        hidden_states = hidden_states + self.scale * ip_hidden_states
+            hidden_states = hidden_states + self.scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index a9fcc25539..a109c9aac3 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -92,7 +92,6 @@ class IPAdapter:
             print("swapping in IPAttnProcessor for", name)
             attn_procs[name] = IPAttnProcessor(
                 hidden_size=hidden_size,
-                image_embedding_len=self.num_tokens,
                 cross_attention_dim=cross_attention_dim,
                 scale=1.0,
             ).to(self.device, dtype=torch.float16)
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 17b358f61c..138e3c9cea 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -30,6 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningData,
+    IPAdapterConditioningInfo,
 )
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -449,27 +450,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # Get image embeddings from CLIP and ImageProjModel.
                 image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
+                conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
+                    image_prompt_embeds, uncond_image_prompt_embeds
+                )
 
-                # The following commented block is kept for reference on how to repeat/reshape the image embeddings to
-                # generate a batch of multiple images:
-                # bs_embed, seq_len, _ = image_prompt_embeds.shape
-                # image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-                # image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-                # uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-                # uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-                text_prompt_embeds = conditioning_data.text_embeddings.embeds
-                uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
-                print("text embeds shape:", text_prompt_embeds.shape)
-                concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
-                concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-                print("concat embeds shape:", concat_prompt_embeds.shape)
-                conditioning_data.text_embeddings.embeds = concat_prompt_embeds
-                conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
-            else:
-                image_prompt_embeds = None
-                uncond_image_prompt_embeds = None
-
+        # TODO(ryand): Apply IP-Adapter or custom attention control
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
             self.invokeai_diffuser.model,
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 2634dd600d..083d925899 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -51,6 +51,18 @@ class PostprocessingSettings:
     v_symmetry_time_pct: Optional[float]
 
 
+@dataclass
+class IPAdapterConditioningInfo:
+    cond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoder conditioning embeddings.
+    Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
+    """
+    uncond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoding embeddings to use for unconditional generation.
+    Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
+    """
+
+
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
@@ -69,6 +81,8 @@ class ConditioningData:
     """
     postprocessing_settings: Optional[PostprocessingSettings] = None
 
+    ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
+
     @property
     def dtype(self):
         return self.text_embeddings.dtype
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index aa53a0435e..8473fa7bcc 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -10,6 +10,7 @@ from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
     ExtraConditioningInfo,
     PostprocessingSettings,
     SDXLConditioningInfo,
@@ -232,6 +233,8 @@ class InvokeAIDiffuserComponent:
         total_step_count: int,
         **kwargs,
     ):
+        # TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
+
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
@@ -339,11 +342,24 @@ class InvokeAIDiffuserComponent:
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
-        # fast batched path
+    def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
+        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
+        the cost of higher memory usage.
+        """
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": torch.cat(
+                    [
+                        conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
+                        conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
+                    ]
+                )
+            }
+
         added_cond_kwargs = None
         if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
@@ -371,6 +387,7 @@ class InvokeAIDiffuserComponent:
             x_twice,
             sigma_twice,
             both_conditionings,
+            cross_attention_kwargs=cross_attention_kwargs,
             encoder_attention_mask=encoder_attention_mask,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
@@ -382,9 +399,12 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        conditioning_data,
+        conditioning_data: ConditioningData,
         **kwargs,
     ):
+        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
+        slower execution speed.
+        """
         # low-memory sequential path
         uncond_down_block, cond_down_block = None, None
         down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@@ -400,6 +420,13 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
+        # Run unconditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
+            }
+
         added_cond_kwargs = None
         is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
@@ -412,12 +439,21 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.unconditioned_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )
 
+        # Run conditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
+            }
+
+        added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@@ -428,6 +464,7 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.text_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
             added_cond_kwargs=added_cond_kwargs,
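
For readers tracing the data flow this patch introduces: the pipeline stores the image-prompt embeddings on ConditioningData.ip_adapter_conditioning, and InvokeAIDiffuserComponent forwards them to the attention processors through the UNet's cross_attention_kwargs. The sketch below illustrates that flow with stand-in tensors; it is not part of the patch, and the tensor shapes and the fake_unet_step helper are assumptions made for illustration only.

# Illustrative sketch (not part of the patch): how IP-Adapter embeddings reach the
# attention processors via cross_attention_kwargs. Shapes are assumed, not confirmed.
from dataclasses import dataclass

import torch


@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
    uncond_image_prompt_embeds: torch.Tensor


def fake_unet_step(sample, encoder_hidden_states, cross_attention_kwargs):
    # A real UNet would pass cross_attention_kwargs down to every attention processor;
    # IPAttnProcessor reads the "ip_adapter_image_prompt_embeds" entry from it.
    ip_embeds = cross_attention_kwargs["ip_adapter_image_prompt_embeds"]
    assert ip_embeds.shape[0] == encoder_hidden_states.shape[0]  # batch dims must match
    assert ip_embeds.shape[2] == encoder_hidden_states.shape[2]  # channel dims must match
    return sample


# Stand-in tensors: (batch, tokens, channels) for text, (batch, 4, channels) for image prompts.
text_embeds = torch.randn(1, 77, 768)
uncond_text_embeds = torch.randn(1, 77, 768)
ip_cond = IPAdapterConditioningInfo(torch.randn(1, 4, 768), torch.randn(1, 4, 768))

# Batched CFG path, mirroring _apply_standard_conditioning: unconditional first, then conditional.
both_text = torch.cat([uncond_text_embeds, text_embeds])
cross_attention_kwargs = {
    "ip_adapter_image_prompt_embeds": torch.cat(
        [ip_cond.uncond_image_prompt_embeds, ip_cond.cond_image_prompt_embeds]
    )
}
latents = torch.randn(2, 4, 64, 64)
fake_unet_step(latents, both_text, cross_attention_kwargs)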