diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py
index dd180eb56c..0f56adacde 100644
--- a/invokeai/backend/flux/custom_block_processor.py
+++ b/invokeai/backend/flux/custom_block_processor.py
@@ -134,5 +134,5 @@ class CustomSingleStreamBlockProcessor:
         """A custom implementation of SingleStreamBlock.forward() with additional features:
         - Masking
         """
-        attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
+        attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
         return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py
index 259db7c29a..7d51e12508 100644
--- a/invokeai/backend/flux/extensions/regional_prompting_extension.py
+++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -18,18 +18,20 @@ class RegionalPromptingExtension:
     def __init__(
         self,
         regional_text_conditioning: FluxRegionalTextConditioning,
-        attn_mask_with_restricted_img_self_attn: torch.Tensor | None = None,
-        attn_mask_with_unrestricted_img_self_attn: torch.Tensor | None = None,
+        restricted_attn_mask: torch.Tensor | None = None,
+        # unrestricted_attn_mask: torch.Tensor | None = None,
     ):
         self.regional_text_conditioning = regional_text_conditioning
-        self.attn_mask_with_restricted_img_self_attn = attn_mask_with_restricted_img_self_attn
-        self.attn_mask_with_unrestricted_img_self_attn = attn_mask_with_unrestricted_img_self_attn
+        self.restricted_attn_mask = restricted_attn_mask
+        # self.unrestricted_attn_mask = unrestricted_attn_mask
 
     def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
-        return self.attn_mask_with_unrestricted_img_self_attn
+        order = [self.restricted_attn_mask, None]
+        return order[block_index % len(order)]
 
-    def get_single_stream_attn_mask(self) -> torch.Tensor | None:
-        return self.attn_mask_with_unrestricted_img_self_attn
+    def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
+        order = [self.restricted_attn_mask, None]
+        return order[block_index % len(order)]
 
     @classmethod
     def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
@@ -40,37 +42,34 @@ class RegionalPromptingExtension:
             img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
""" regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning) - attn_mask_with_restricted_img_self_attn = cls._prepare_attn_mask( - regional_text_conditioning, img_seq_len, restrict_img_self_attn=True - ) - attn_mask_with_unrestricted_img_self_attn = cls._prepare_attn_mask( - regional_text_conditioning, img_seq_len, restrict_img_self_attn=False + attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask( + regional_text_conditioning, img_seq_len ) + # attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask( + # regional_text_conditioning, img_seq_len + # ) return cls( regional_text_conditioning=regional_text_conditioning, - attn_mask_with_restricted_img_self_attn=attn_mask_with_restricted_img_self_attn, - attn_mask_with_unrestricted_img_self_attn=attn_mask_with_unrestricted_img_self_attn, + restricted_attn_mask=attn_mask_with_restricted_img_self_attn, + # unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn, ) @classmethod - def _prepare_attn_mask( + def _prepare_unrestricted_attn_mask( cls, regional_text_conditioning: FluxRegionalTextConditioning, img_seq_len: int, - restrict_img_self_attn: bool, ) -> torch.Tensor: + """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that: + - img self-attention is not masked. + - img regions attend to both txt within their own region and to global prompts. + """ device = TorchDevice.choose_torch_device() # Infer txt_seq_len from the t5_embeddings tensor. txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] - # Decide whether to compute the img self-attention region mask. - # When compute_img_self_attn_region_mask is True, img self attention is only allowed within regions. - # When compute_img_self_attn_region_mask is False, img self attention is not constrained. - has_region_masks = any(mask is not None for mask in regional_text_conditioning.image_masks) - compute_img_self_attn_region_mask = restrict_img_self_attn and has_region_masks - - # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. # Concatenation happens in the following order: [txt_seq, img_seq]. # There are 4 portions of the attention mask to consider as we prepare it: # 1. txt attends to itself @@ -101,14 +100,87 @@ class RegionalPromptingExtension: fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value - # 4. regional img attends to itself - if compute_img_self_attn_region_mask and image_mask is not None: - image_mask = image_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + # 4. regional img attends to itself + # Allow unrestricted img self attention. + regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 - if not compute_img_self_attn_region_mask: - # Allow unrestricted img self attention. + # Convert attention mask to boolean. + regional_attention_mask = regional_attention_mask > 0.5 + + return regional_attention_mask + + @classmethod + def _prepare_restricted_attn_mask( + cls, + regional_text_conditioning: FluxRegionalTextConditioning, + img_seq_len: int, + ) -> torch.Tensor: + """Prepare a 'restricted' attention mask. In this context, 'restricted' means that: + - img self-attention is only allowed within regions. 
+        - img regions only attend to txt within their own region, not to global prompts.
+
+        """
+        device = TorchDevice.choose_torch_device()
+
+        # Infer txt_seq_len from the t5_embeddings tensor.
+        txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
+
+        # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+        # Concatenation happens in the following order: [txt_seq, img_seq].
+        # There are 4 portions of the attention mask to consider as we prepare it:
+        # 1. txt attends to itself
+        # 2. txt attends to corresponding regional img
+        # 3. regional img attends to corresponding txt
+        # 4. regional img attends to itself
+
+        # Initialize empty attention mask.
+        regional_attention_mask = torch.zeros(
+            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
+        )
+
+        # Identify the background region, i.e. the region that is not covered by any region mask.
+        background_region_mask: None | torch.Tensor = None
+        for image_mask in regional_text_conditioning.image_masks:
+            if image_mask is not None:
+                if background_region_mask is None:
+                    background_region_mask = torch.ones_like(image_mask)
+                background_region_mask *= 1 - image_mask
+
+        for image_mask, t5_embedding_range in zip(
+            regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
+        ):
+            # 1. txt attends to itself
+            regional_attention_mask[
+                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+            ] = 1.0
+
+            if image_mask is None:
+                continue
+
+            # 2. txt attends to corresponding regional img
+            # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+            regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view(
+                1, img_seq_len
+            )
+
+            # 3. regional img attends to corresponding txt
+            # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+            regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view(
+                img_seq_len, 1
+            )
+
+            # 4. regional img attends to itself
+            image_mask = image_mask.view(img_seq_len, 1)
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
+
+        # Handle image background regions.
+        if background_region_mask is None:
+            # There are no region masks, so allow unrestricted img self attention.
             regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
+        else:
+            # Allow background regions to attend to themselves and to the rest of the image.
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
 
         # Convert attention mask to boolean.
         regional_attention_mask = regional_attention_mask > 0.5
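Note for reviewers: after this change, get_double_stream_attn_mask() and get_single_stream_attn_mask() return order[block_index % 2], so even-indexed blocks apply the restricted mask and odd-indexed blocks run with no mask at all. The sketch below is a minimal, standalone illustration (not part of the diff) of the four-quadrant layout that _prepare_restricted_attn_mask builds, on toy sizes; TxtRange and restricted_attn_mask_sketch are hypothetical names introduced here, with TxtRange standing in for the real t5_embedding_ranges entries.

from dataclasses import dataclass

import torch


@dataclass
class TxtRange:
    """Hypothetical stand-in for an entry of t5_embedding_ranges."""

    start: int
    end: int


def restricted_attn_mask_sketch(
    image_masks: list[torch.Tensor | None],  # each (img_seq_len,), 0/1 values; None = global prompt
    ranges: list[TxtRange],
    txt_seq_len: int,
    img_seq_len: int,
) -> torch.Tensor:
    n = txt_seq_len + img_seq_len
    mask = torch.zeros((n, n))

    # Background = img tokens not covered by any region mask.
    background: torch.Tensor | None = None
    for m in image_masks:
        if m is not None:
            background = torch.ones_like(m) if background is None else background
            background *= 1 - m

    for m, r in zip(image_masks, ranges, strict=True):
        mask[r.start : r.end, r.start : r.end] = 1.0  # 1. txt attends to itself
        if m is None:
            continue  # global prompts get no img cross-attention in the restricted mask
        mask[r.start : r.end, txt_seq_len:] = m.view(1, img_seq_len)  # 2. txt -> its img region
        mask[txt_seq_len:, r.start : r.end] = m.view(img_seq_len, 1)  # 3. img region -> its txt
        col = m.view(img_seq_len, 1)
        mask[txt_seq_len:, txt_seq_len:] += col @ col.T  # 4. img self-attention within the region

    if background is None:
        mask[txt_seq_len:, txt_seq_len:] = 1.0  # no region masks: leave img self-attention unmasked
    else:
        # Background tokens may attend to, and be attended by, the whole image.
        mask[txt_seq_len:, txt_seq_len:] += background.view(img_seq_len, 1)
        mask[txt_seq_len:, txt_seq_len:] += background.view(1, img_seq_len)
    return mask > 0.5


# Toy example: two 2-token prompts (txt_seq_len=4) and 3 img tokens; prompt 1
# owns img token 0, prompt 2 owns img tokens 1-2, so the background is empty.
m1 = torch.tensor([1.0, 0.0, 0.0])
m2 = torch.tensor([0.0, 1.0, 1.0])
print(
    restricted_attn_mask_sketch(
        [m1, m2], [TxtRange(0, 2), TxtRange(2, 4)], txt_seq_len=4, img_seq_len=3
    ).int()
)

Printing the toy mask shows the block-diagonal txt quadrant, the txt/img cross quadrants restricted to each prompt's own region, and an img self-attention quadrant where token 0 and tokens 1-2 form separate attention groups.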