Distinguish between restricted and unrestricted attn masks in FLUX regional prompting.

This commit is contained in:
Ryan Dick
2024-11-26 16:55:52 +00:00
parent e01f66b026
commit faee79dc95
2 changed files with 102 additions and 30 deletions

View File

@@ -134,5 +134,5 @@ class CustomSingleStreamBlockProcessor:
"""A custom implementation of SingleStreamBlock.forward() with additional features:
- Masking
"""
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)

View File

@@ -18,18 +18,20 @@ class RegionalPromptingExtension:
def __init__(
self,
regional_text_conditioning: FluxRegionalTextConditioning,
attn_mask_with_restricted_img_self_attn: torch.Tensor | None = None,
attn_mask_with_unrestricted_img_self_attn: torch.Tensor | None = None,
restricted_attn_mask: torch.Tensor | None = None,
# unrestricted_attn_mask: torch.Tensor | None = None,
):
self.regional_text_conditioning = regional_text_conditioning
self.attn_mask_with_restricted_img_self_attn = attn_mask_with_restricted_img_self_attn
self.attn_mask_with_unrestricted_img_self_attn = attn_mask_with_unrestricted_img_self_attn
self.restricted_attn_mask = restricted_attn_mask
# self.unrestricted_attn_mask = unrestricted_attn_mask
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
return self.attn_mask_with_unrestricted_img_self_attn
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
def get_single_stream_attn_mask(self) -> torch.Tensor | None:
return self.attn_mask_with_unrestricted_img_self_attn
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
@@ -40,37 +42,34 @@ class RegionalPromptingExtension:
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_attn_mask(
regional_text_conditioning, img_seq_len, restrict_img_self_attn=True
)
attn_mask_with_unrestricted_img_self_attn = cls._prepare_attn_mask(
regional_text_conditioning, img_seq_len, restrict_img_self_attn=False
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
# attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask(
# regional_text_conditioning, img_seq_len
# )
return cls(
regional_text_conditioning=regional_text_conditioning,
attn_mask_with_restricted_img_self_attn=attn_mask_with_restricted_img_self_attn,
attn_mask_with_unrestricted_img_self_attn=attn_mask_with_unrestricted_img_self_attn,
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
# unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn,
)
@classmethod
def _prepare_attn_mask(
def _prepare_unrestricted_attn_mask(
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
restrict_img_self_attn: bool,
) -> torch.Tensor:
"""Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
- img self-attention is not masked.
- img regions attend to both txt within their own region and to global prompts.
"""
device = TorchDevice.choose_torch_device()
# Infer txt_seq_len from the t5_embeddings tensor.
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# Decide whether to compute the img self-attention region mask.
# When compute_img_self_attn_region_mask is True, img self attention is only allowed within regions.
# When compute_img_self_attn_region_mask is False, img self attention is not constrained.
has_region_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
compute_img_self_attn_region_mask = restrict_img_self_attn and has_region_masks
# In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
@@ -101,14 +100,87 @@ class RegionalPromptingExtension:
fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
# 4. regional img attends to itself
if compute_img_self_attn_region_mask and image_mask is not None:
image_mask = image_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
# 4. regional img attends to itself
# Allow unrestricted img self attention.
regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
if not compute_img_self_attn_region_mask:
# Allow unrestricted img self attention.
# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5
return regional_attention_mask
@classmethod
def _prepare_restricted_attn_mask(
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor:
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
- img self-attention is only allowed within regions.
- img regions only attend to txt within their own region, not to global prompts.
"""
device = TorchDevice.choose_torch_device()
# Infer txt_seq_len from the t5_embeddings tensor.
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
# 2. txt attends to corresponding regional img
# 3. regional img attends to corresponding txt
# 4. regional img attends to itself
# Initialize empty attention mask.
regional_attention_mask = torch.zeros(
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
)
# Identify the background region, i.e. the portion of the image that is not covered by any region mask.
background_region_mask: None | torch.Tensor = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
if background_region_mask is None:
background_region_mask = torch.ones_like(image_mask)
background_region_mask *= 1 - image_mask
for image_mask, t5_embedding_range in zip(
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
):
# 1. txt attends to itself
regional_attention_mask[
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
] = 1.0
if image_mask is None:
continue
# 2. txt attends to corresponding regional img
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view(
1, img_seq_len
)
# 3. regional img attends to corresponding txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view(
img_seq_len, 1
)
# 4. regional img attends to itself
image_mask = image_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
# Handle image background regions.
if background_region_mask is None:
# There are no region masks, so allow unrestricted img self attention.
regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
else:
# Allow background regions to attend to themselves and to the rest of the image.
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5