From d27907cc6d4f7939ac796dd1c2fdffb1652dc5bf Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Tue, 16 Apr 2024 04:29:53 +0530 Subject: [PATCH] fix: entire reshaping block needs to be skipped --- .../diffusion/custom_atttention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 8d7245ae3b..16617d049a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -178,15 +178,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) - ip_hidden_states = ip_hidden_states.to(query.dtype) + ip_hidden_states = ip_hidden_states.to(query.dtype) - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask else: # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. assert regional_ip_data is None