mirror of https://github.com/invoke-ai/InvokeAI.git
WIP - add rough logic for preparing the FLUX regional prompting attention mask.
@@ -10,7 +10,10 @@ from invokeai.backend.util.mask import to_standard_float_mask
 
 
 class RegionalPromptingExtension:
-    """A class for managing regional prompting with FLUX."""
+    """A class for managing regional prompting with FLUX.
+
+    Implementation inspired by: https://arxiv.org/pdf/2411.02395
+    """
 
     def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning):
         self.regional_text_conditioning = regional_text_conditioning
@@ -19,6 +22,51 @@ class RegionalPromptingExtension:
     def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]):
         return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning))
+
+    def _prepare_attn_mask(self) -> torch.Tensor:
+        device = self.regional_text_conditioning.image_masks[0].device
+        # img_seq_len = latent_height * latent_width
+        img_seq_len = (
+            self.regional_text_conditioning.image_masks.shape[-1]
+            * self.regional_text_conditioning.image_masks.shape[-2]
+        )
+        txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1]
+
+        # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+        # Concatenation happens in the following order: [txt_seq, img_seq].
+        # There are 4 portions of the attention mask to consider as we prepare it:
+        # 1. txt attends to itself
+        # 2. txt attends to corresponding regional img
+        # 3. regional img attends to corresponding txt
+        # 4. regional img attends to itself
+
+        # Initialize empty attention mask.
+        regional_attention_mask = torch.zeros(
+            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool
+        )
+
+        for i in range(len(self.regional_text_conditioning.t5_embeddings)):
+            image_mask = self.regional_text_conditioning.image_masks[i].flatten()
+            t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i]
+
+            # 1. txt attends to itself
+            regional_attention_mask[
+                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+            ] = True
+
+            # 2. txt attends to corresponding regional img
+            # TODO(ryand): Make sure that broadcasting works as expected.
+            regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask
+
+            # 3. regional img attends to corresponding txt
+            # TODO(ryand): Make sure that broadcasting works as expected.
+            regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask
+
+            # 4. regional img attends to itself
+            # TODO(ryand): Make sure that broadcasting works as expected.
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
+
+        return regional_attention_mask
 
     @classmethod
     def _concat_regional_text_conditioning(
         cls,
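For orientation, below is a minimal standalone sketch of the masking scheme that the new _prepare_attn_mask method builds, using toy sizes. The names (txt_ranges, image_masks) and shapes are illustrative placeholders rather than the real FluxRegionalTextConditioning fields, and the explicit unsqueeze/outer-product broadcasting plus |= accumulation is just one plausible way to resolve the TODOs flagged above, not necessarily what the final InvokeAI implementation will do.

import torch

# Toy sizes: 2 regional prompts with 4 T5 tokens each, and a 2x3 latent grid (6 image tokens).
txt_seq_len = 8
img_seq_len = 6
txt_ranges = [(0, 4), (4, 8)]  # token span of each prompt in the packed txt sequence
image_masks = torch.tensor(
    [
        [1, 1, 1, 0, 0, 0],  # region 1: first half of the flattened latent grid
        [0, 0, 0, 1, 1, 1],  # region 2: second half
    ],
    dtype=torch.bool,
)

# Joint sequence order is [txt_seq, img_seq], so the mask is (txt+img) x (txt+img).
mask = torch.zeros(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len, dtype=torch.bool)

for (start, end), image_mask in zip(txt_ranges, image_masks):
    # 1. txt attends to itself (only within its own prompt's token span).
    mask[start:end, start:end] = True
    # 2. txt attends to its region's img tokens (1-D mask broadcasts across the txt rows).
    mask[start:end, txt_seq_len:] |= image_mask
    # 3. the region's img tokens attend back to that prompt's txt tokens.
    mask[txt_seq_len:, start:end] |= image_mask.unsqueeze(-1)
    # 4. img tokens in the same region attend to each other (outer product, OR-ed across regions).
    mask[txt_seq_len:, txt_seq_len:] |= image_mask.unsqueeze(-1) & image_mask.unsqueeze(0)

print(mask.int())  # visualize the four quadrants of the joint attention mask

Accumulating with |= keeps earlier regions intact when later regions write into the shared img-to-img quadrant; plain assignment there, as in the WIP loop above, would let the last region overwrite the others.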