Fixes to get XLabsIpAdapterExtension running.

This commit is contained in:
Ryan Dick
2024-10-16 01:39:48 +00:00
parent 31ffd73423
commit fdccdd52d5

View File

@@ -75,14 +75,14 @@ class XLabsIPAdapterExtension:
ip_value = ip_adapter_block.ip_adapter_double_stream_v_proj(self._image_proj)
# Reshape projections for multi-head attention.
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=block.num_heads, D=block.head_dim)
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=block.num_heads, D=block.head_dim)
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=block.num_heads)
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=block.num_heads)
# Compute attention between IP projections and the latent query.
ip_attn = torch.nn.functional.scaled_dot_product_attention(
img_q, ip_key, ip_value, dropout_p=0.0, is_causal=False
)
ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)", H=block.num_heads, D=block.head_dim)
ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)", H=block.num_heads)
img = img + weight * ip_attn