WIP - add rough logic for preparing the FLUX regional prompting attention mask.

This commit is contained in:
Ryan Dick
2024-11-20 22:29:36 +00:00
parent fda7aaa7ca
commit bad1149504

View File

@@ -10,7 +10,10 @@ from invokeai.backend.util.mask import to_standard_float_mask
class RegionalPromptingExtension:
    """A class for managing regional prompting with FLUX.

    Implementation inspired by: https://arxiv.org/pdf/2411.02395
    """
def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning):
    """Store the pre-concatenated regional text conditioning used to build attention masks."""
    # regional_text_conditioning bundles the per-region masks, T5 embeddings, and
    # embedding ranges produced by `_concat_regional_text_conditioning`.
    self.regional_text_conditioning = regional_text_conditioning
@@ -19,6 +22,51 @@ class RegionalPromptingExtension:
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]):
    """Build a RegionalPromptingExtension from a list of per-region text conditionings."""
    regional_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
    return cls(regional_text_conditioning=regional_conditioning)
def _prepare_attn_mask(self) -> torch.Tensor:
device = self.regional_text_conditioning.image_masks[0].device
# img_seq_len = latent_height * latent_width
img_seq_len = (
self.regional_text_conditioning.image_masks.shape[-1]
* self.regional_text_conditioning.image_masks.shape[-2]
)
txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1]
# In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
# 2. txt attends to corresponding regional img
# 3. regional img attends to corresponding txt
# 4. regional img attends to itself
# Initialize empty attention mask.
regional_attention_mask = torch.zeros(
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool
)
for i in range(len(self.regional_text_conditioning.t5_embeddings)):
image_mask = self.regional_text_conditioning.image_masks[i].flatten()
t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i]
# 1. txt attends to itself
regional_attention_mask[
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
] = True
# 2. txt attends to corresponding regional img
# TODO(ryand): Make sure that broadcasting works as expected.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask
# 3. regional img attends to corresponding txt
# TODO(ryand): Make sure that broadcasting works as expected.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask
# 4. regional img attends to itself
# TODO(ryand): Make sure that broadcasting works as expected.
regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
return regional_attention_mask
@classmethod
def _concat_regional_text_conditioning(
cls,