Short-circuit if there are no region masks in FLUX and don't apply attention masking.

This commit is contained in:
Ryan Dick
2024-11-27 22:40:10 +00:00
parent 6565cea039
commit 64364e7911

View File

@@ -110,12 +110,25 @@ class RegionalPromptingExtension:
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor:
) -> torch.Tensor | None:
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
- img self-attention is only allowed within regions.
- img regions only attend to txt within their own region, not to global prompts.
"""
# Identify background region. I.e. the region that is not covered by any region masks.
background_region_mask: None | torch.Tensor = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
if background_region_mask is None:
background_region_mask = torch.ones_like(image_mask)
background_region_mask *= 1 - image_mask
if background_region_mask is None:
# There are no region masks, short-circuit and return None.
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would
# is a rare use case and would make the logic here significantly more complicated.
return None
device = TorchDevice.choose_torch_device()
# Infer txt_seq_len from the t5_embeddings tensor.
@@ -134,14 +147,6 @@ class RegionalPromptingExtension:
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
)
# Identify background region. I.e. the region that is not covered by any region masks.
background_region_mask: None | torch.Tensor = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
if background_region_mask is None:
background_region_mask = torch.ones_like(image_mask)
background_region_mask *= 1 - image_mask
for image_mask, t5_embedding_range in zip(
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
):
@@ -167,10 +172,6 @@ class RegionalPromptingExtension:
image_mask = image_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
else:
if background_region_mask is None:
# There are no region masks, so we don't need to do anything here - this case is handled below.
continue
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
# that regions focus on their local prompts. We do, however, allow attention between background regions
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
@@ -188,15 +189,9 @@ class RegionalPromptingExtension:
background_region_mask.view(img_seq_len, 1)
)
# Handle image background regions.
if background_region_mask is None:
# There are no region masks, so allow unrestricted img-img attention, and unrestricted img-txt attention.
regional_attention_mask[txt_seq_len:, :] = 1.0
regional_attention_mask[:, txt_seq_len:] = 1.0
else:
# Allow background regions to attend to themselves.
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
# Allow background regions to attend to themselves.
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5