diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 8d7245ae3b..16617d049a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -178,15 +178,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) - ip_hidden_states = ip_hidden_states.to(query.dtype) + ip_hidden_states = ip_hidden_states.to(query.dtype) - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask else: # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. assert regional_ip_data is None