mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Tweak flux regional prompting attention scheme based on latest experimentation.
This commit is contained in:
@@ -154,24 +154,43 @@ class RegionalPromptingExtension:
|
||||
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
] = 1.0
|
||||
|
||||
if image_mask is None:
|
||||
continue
|
||||
if image_mask is not None:
|
||||
# 2. txt attends to corresponding regional img
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
image_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 2. txt attends to corresponding regional img
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view(
|
||||
1, img_seq_len
|
||||
)
|
||||
# 3. regional img attends to corresponding txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
image_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# 3. regional img attends to corresponding txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view(
|
||||
img_seq_len, 1
|
||||
)
|
||||
# 4. regional img attends to itself
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
else:
|
||||
if background_region_mask is None:
|
||||
# There are no region masks, so we don't need to do anything here - this case is handled below.
|
||||
continue
|
||||
|
||||
# 4. regional img attends to itself
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
|
||||
# that regions focus on their local prompts. We do, however, allow attention between background regions
|
||||
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
|
||||
# embeddings, which we found experimentally to cause artifacts.
|
||||
|
||||
# 2. global txt attends to background region
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
background_region_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. background region attends to global txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
background_region_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# Handle image background regions.
|
||||
if background_region_mask is None:
|
||||
@@ -179,9 +198,9 @@ class RegionalPromptingExtension:
|
||||
regional_attention_mask[txt_seq_len:, :] = 1.0
|
||||
regional_attention_mask[:, txt_seq_len:] = 1.0
|
||||
else:
|
||||
# Allow background regions to attend to themselves and to the entire txt embedding.
|
||||
regional_attention_mask[txt_seq_len:, :] += background_region_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
|
||||
# Allow background regions to attend to themselves.
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
|
||||
|
||||
# Convert attention mask to boolean.
|
||||
regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
Reference in New Issue
Block a user