diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index c3eb8e542f..f5f203af69 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -154,24 +154,43 @@ class RegionalPromptingExtension: t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end ] = 1.0 - if image_mask is None: - continue + if image_mask is not None: + # 2. txt attends to corresponding regional img + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + image_mask.view(1, img_seq_len) + ) - # 2. txt attends to corresponding regional img - # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. - regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view( - 1, img_seq_len - ) + # 3. regional img attends to corresponding txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + image_mask.view(img_seq_len, 1) + ) - # 3. regional img attends to corresponding txt - # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. - regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view( - img_seq_len, 1 - ) + # 4. regional img attends to itself + image_mask = image_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + else: + if background_region_mask is None: + # There are no region masks, so we don't need to do anything here - this case is handled below. + continue - # 4. regional img attends to itself - image_mask = image_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + # We don't allow attention between non-background image regions and global prompts. This helps to ensure + # that regions focus on their local prompts. We do, however, allow attention between background regions + # and global prompts. If we didn't do this, then the background regions would not attend to any txt + # embeddings, which we found experimentally to cause artifacts. + + # 2. global txt attends to background region + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + background_region_mask.view(1, img_seq_len) + ) + + # 3. background region attends to global txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + background_region_mask.view(img_seq_len, 1) + ) # Handle image background regions. if background_region_mask is None: @@ -179,9 +198,9 @@ class RegionalPromptingExtension: regional_attention_mask[txt_seq_len:, :] = 1.0 regional_attention_mask[:, txt_seq_len:] = 1.0 else: - # Allow background regions to attend to themselves and to the entire txt embedding. - regional_attention_mask[txt_seq_len:, :] += background_region_mask.view(img_seq_len, 1) - regional_attention_mask[:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + # Allow background regions to attend to themselves. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5