Tweak flux regional prompting attention scheme based on latest experimentation.

Author: Ryan Dick
Date: 2024-11-27 22:13:07 +00:00
parent fa5653cdf7
commit e970185161

@@ -154,24 +154,43 @@ class RegionalPromptingExtension:
                     t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
                 ] = 1.0
 
-            if image_mask is None:
-                continue
-
-            # 2. txt attends to corresponding regional img
-            # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
-            regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view(
-                1, img_seq_len
-            )
-
-            # 3. regional img attends to corresponding txt
-            # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
-            regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view(
-                img_seq_len, 1
-            )
-
-            # 4. regional img attends to itself
-            image_mask = image_mask.view(img_seq_len, 1)
-            regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
+            if image_mask is not None:
+                # 2. txt attends to corresponding regional img
+                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
+                    image_mask.view(1, img_seq_len)
+                )
+
+                # 3. regional img attends to corresponding txt
+                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
+                    image_mask.view(img_seq_len, 1)
+                )
+
+                # 4. regional img attends to itself
+                image_mask = image_mask.view(img_seq_len, 1)
+                regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
+            else:
+                if background_region_mask is None:
+                    # There are no region masks, so we don't need to do anything here - this case is handled below.
+                    continue
+
+                # We don't allow attention between non-background image regions and global prompts. This helps to ensure
+                # that regions focus on their local prompts. We do, however, allow attention between background regions
+                # and global prompts. If we didn't do this, then the background regions would not attend to any txt
+                # embeddings, which we found experimentally to cause artifacts.
+
+                # 2. global txt attends to background region
+                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
+                    background_region_mask.view(1, img_seq_len)
+                )
+
+                # 3. background region attends to global txt
+                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
+                    background_region_mask.view(img_seq_len, 1)
+                )
 
         # Handle image background regions.
         if background_region_mask is None:
@@ -179,9 +198,9 @@ class RegionalPromptingExtension:
             regional_attention_mask[txt_seq_len:, :] = 1.0
             regional_attention_mask[:, txt_seq_len:] = 1.0
         else:
-            # Allow background regions to attend to themselves and to the entire txt embedding.
-            regional_attention_mask[txt_seq_len:, :] += background_region_mask.view(img_seq_len, 1)
-            regional_attention_mask[:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
+            # Allow background regions to attend to themselves.
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
+            regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
 
         # Convert attention mask to boolean.
         regional_attention_mask = regional_attention_mask > 0.5
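
The new scheme is easier to study outside the diff. Below is a minimal, self-contained sketch of the resulting mask construction with toy sequence lengths; the names (region_masks, t5_ranges) and all dimensions are illustrative assumptions, not the InvokeAI implementation.

import torch

# Toy dimensions: two prompts of 3 T5 tokens each, and 4 image tokens.
txt_seq_len = 6
img_seq_len = 4
seq_len = txt_seq_len + img_seq_len

# One mask per prompt, 1.0 where the prompt's region covers an image token.
# A None entry stands in for a global prompt with no region mask.
region_masks = [
    torch.tensor([1.0, 1.0, 0.0, 0.0]),  # regional prompt A
    None,                                # global prompt
]
t5_ranges = [(0, 3), (3, 6)]  # txt token range owned by each prompt

# Background = image tokens covered by no region mask.
background_region_mask = torch.tensor([0.0, 0.0, 1.0, 1.0])

mask = torch.zeros(seq_len, seq_len)
for image_mask, (start, end) in zip(region_masks, t5_ranges):
    # 1. txt attends to itself.
    mask[start:end, start:end] = 1.0
    if image_mask is not None:
        # 2. txt attends to its regional img tokens (row broadcast).
        mask[start:end, txt_seq_len:] = image_mask.view(1, img_seq_len)
        # 3. regional img tokens attend to their txt tokens (column broadcast).
        mask[txt_seq_len:, start:end] = image_mask.view(img_seq_len, 1)
        # 4. regional img attends to itself: the outer product is > 0 exactly
        # where both the query and the key token lie inside the region.
        m = image_mask.view(img_seq_len, 1)
        mask[txt_seq_len:, txt_seq_len:] += m @ m.T
    elif background_region_mask is not None:
        # Global prompts exchange attention only with background img tokens,
        # so masked regions stay focused on their local prompts.
        mask[start:end, txt_seq_len:] = background_region_mask.view(1, img_seq_len)
        mask[txt_seq_len:, start:end] = background_region_mask.view(img_seq_len, 1)

if background_region_mask is None:
    # No region masks at all: img attends to everything and vice versa.
    mask[txt_seq_len:, :] = 1.0
    mask[:, txt_seq_len:] = 1.0
else:
    # The two broadcast adds make every img-img pair with a background query
    # or a background key pass the > 0.5 threshold below.
    mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
    mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)

bool_mask = mask > 0.5  # True = attention allowed
print(bool_mask.int())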
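
Continuing the sketch above, the resulting boolean mask drops straight into PyTorch's scaled-dot-product attention, which treats a True entry in attn_mask as "this query may attend to this key". This is only a generic shape check with toy values, not how the extension wires the mask into the FLUX transformer.

import torch.nn.functional as F

head_dim = 8
q = torch.randn(1, 1, seq_len, head_dim)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, seq_len, head_dim)
v = torch.randn(1, 1, seq_len, head_dim)

# The (seq_len, seq_len) boolean mask broadcasts over the batch and head dims.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
print(out.shape)  # torch.Size([1, 1, 10, 8])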