diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py
index 21b2279b14..43a3cc4395 100644
--- a/invokeai/backend/flux/extensions/regional_prompting_extension.py
+++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -10,7 +10,10 @@ from invokeai.backend.util.mask import to_standard_float_mask
 
 
 class RegionalPromptingExtension:
-    """A class for managing regional prompting with FLUX."""
+    """A class for managing regional prompting with FLUX.
+
+    Implementation inspired by: https://arxiv.org/pdf/2411.02395
+    """
 
     def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning):
         self.regional_text_conditioning = regional_text_conditioning
@@ -19,6 +22,51 @@ class RegionalPromptingExtension:
     def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]):
         return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning))
 
+    def _prepare_attn_mask(self) -> torch.Tensor:
+        device = self.regional_text_conditioning.image_masks[0].device
+        # img_seq_len = latent_height * latent_width
+        img_seq_len = (
+            self.regional_text_conditioning.image_masks.shape[-1]
+            * self.regional_text_conditioning.image_masks.shape[-2]
+        )
+        txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1]
+
+        # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+        # Concatenation happens in the following order: [txt_seq, img_seq].
+        # There are 4 portions of the attention mask to consider as we prepare it:
+        # 1. txt attends to itself
+        # 2. txt attends to corresponding regional img
+        # 3. regional img attends to corresponding txt
+        # 4. regional img attends to itself
+
+        # Initialize empty attention mask.
+        regional_attention_mask = torch.zeros(
+            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool
+        )
+
+        for i in range(len(self.regional_text_conditioning.t5_embeddings)):
+            image_mask = self.regional_text_conditioning.image_masks[i].flatten().bool()
+            t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i]
+
+            # 1. txt attends to itself
+            regional_attention_mask[
+                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+            ] = True
+
+            # 2. txt attends to corresponding regional img
+            # image_mask was cast to bool above (non-zero -> True); (img_seq_len,) broadcasts across each txt row.
+            regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask
+
+            # 3. regional img attends to corresponding txt
+            # Unsqueeze to an (img_seq_len, 1) column so it broadcasts across the txt columns of this slice.
+            regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask[:, None]
+
+            # 4. regional img attends to itself
+            # Outer product of the 1D region mask with itself, OR-accumulated so earlier regions are preserved.
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] |= image_mask[:, None] & image_mask[None, :]
+
+        return regional_attention_mask
+
     @classmethod
     def _concat_regional_text_conditioning(
         cls,