mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 09:15:11 -05:00
Distinguish between restricted and unrestricted attn masks in FLUX regional prompting.
This commit is contained in:
@@ -134,5 +134,5 @@ class CustomSingleStreamBlockProcessor:
|
||||
"""A custom implementation of SingleStreamBlock.forward() with additional features:
|
||||
- Masking
|
||||
"""
|
||||
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
|
||||
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
|
||||
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
|
||||
|
||||
@@ -18,18 +18,20 @@ class RegionalPromptingExtension:
|
||||
def __init__(
    self,
    regional_text_conditioning: FluxRegionalTextConditioning,
    restricted_attn_mask: torch.Tensor | None = None,
    # unrestricted_attn_mask: torch.Tensor | None = None,
):
    """Initialize the regional prompting extension.

    Args:
        regional_text_conditioning: The concatenated regional text conditioning data.
        restricted_attn_mask: A 'restricted' attention mask (see
            `_prepare_restricted_attn_mask`): img self-attention is only allowed
            within regions, and img regions attend only to their own txt. None
            means no mask is applied.
    """
    self.regional_text_conditioning = regional_text_conditioning
    self.restricted_attn_mask = restricted_attn_mask
    # NOTE(review): the unrestricted mask is intentionally disabled for now; the
    # commented-out parameter above and `_prepare_unrestricted_attn_mask` are the
    # hooks for re-enabling it.
    # self.unrestricted_attn_mask = unrestricted_attn_mask
|
||||
|
||||
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
return self.attn_mask_with_unrestricted_img_self_attn
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
def get_single_stream_attn_mask(self) -> torch.Tensor | None:
|
||||
return self.attn_mask_with_unrestricted_img_self_attn
|
||||
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
    """Build a RegionalPromptingExtension from a list of text conditionings.

    Args:
        text_conditioning: The per-region text conditionings to concatenate.
        img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
    """
    regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
    restricted_attn_mask = cls._prepare_restricted_attn_mask(regional_text_conditioning, img_seq_len)
    # NOTE(review): the unrestricted mask is intentionally disabled for now:
    # unrestricted_attn_mask = cls._prepare_unrestricted_attn_mask(regional_text_conditioning, img_seq_len)
    return cls(
        regional_text_conditioning=regional_text_conditioning,
        restricted_attn_mask=restricted_attn_mask,
        # unrestricted_attn_mask=unrestricted_attn_mask,
    )
|
||||
|
||||
@classmethod
|
||||
def _prepare_attn_mask(
|
||||
def _prepare_unrestricted_attn_mask(
|
||||
cls,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
img_seq_len: int,
|
||||
restrict_img_self_attn: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
|
||||
- img self-attention is not masked.
|
||||
- img regions attend to both txt within their own region and to global prompts.
|
||||
"""
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
# Infer txt_seq_len from the t5_embeddings tensor.
|
||||
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# Decide whether to compute the img self-attention region mask.
|
||||
# When compute_img_self_attn_region_mask is True, img self attention is only allowed within regions.
|
||||
# When compute_img_self_attn_region_mask is False, img self attention is not constrained.
|
||||
has_region_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
|
||||
compute_img_self_attn_region_mask = restrict_img_self_attn and has_region_masks
|
||||
|
||||
# In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# 1. txt attends to itself
|
||||
@@ -101,14 +100,87 @@ class RegionalPromptingExtension:
|
||||
fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
|
||||
|
||||
# 4. regional img attends to itself
|
||||
if compute_img_self_attn_region_mask and image_mask is not None:
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
# 4. regional img attends to itself
|
||||
# Allow unrestricted img self attention.
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
|
||||
|
||||
if not compute_img_self_attn_region_mask:
|
||||
# Allow unrestricted img self attention.
|
||||
# Convert attention mask to boolean.
|
||||
regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
return regional_attention_mask
|
||||
|
||||
@classmethod
def _prepare_restricted_attn_mask(
    cls,
    regional_text_conditioning: FluxRegionalTextConditioning,
    img_seq_len: int,
) -> torch.Tensor:
    """Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
    - img self-attention is only allowed within regions.
    - img regions only attend to txt within their own region, not to global prompts.

    Args:
        regional_text_conditioning: The concatenated regional text conditioning data.
        img_seq_len: The image sequence length (i.e. packed_height * packed_width).

    Returns:
        A boolean attention mask of shape
        (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), True where
        attention is allowed.
    """
    device = TorchDevice.choose_torch_device()

    # Infer txt_seq_len from the t5_embeddings tensor.
    txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

    # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
    # Concatenation happens in the following order: [txt_seq, img_seq].
    # There are 4 portions of the attention mask to consider as we prepare it:
    # 1. txt attends to itself
    # 2. txt attends to corresponding regional img
    # 3. regional img attends to corresponding txt
    # 4. regional img attends to itself

    # Initialize empty attention mask. Accumulated as float16, then thresholded
    # to boolean at the end.
    regional_attention_mask = torch.zeros(
        (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
    )

    # Identify the background region, i.e. the part of the image that is not
    # covered by any region mask. Stays None when no region masks exist at all.
    background_region_mask: None | torch.Tensor = None
    for image_mask in regional_text_conditioning.image_masks:
        if image_mask is not None:
            if background_region_mask is None:
                background_region_mask = torch.ones_like(image_mask)
            background_region_mask *= 1 - image_mask

    for image_mask, t5_embedding_range in zip(
        regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
    ):
        # 1. txt attends to itself
        regional_attention_mask[
            t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
        ] = 1.0

        # Global prompts (no image mask) get only txt-txt attention in the
        # restricted mask: img regions must not attend to them.
        if image_mask is None:
            continue

        # 2. txt attends to corresponding regional img
        # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
        regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view(
            1, img_seq_len
        )

        # 3. regional img attends to corresponding txt
        # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
        regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view(
            img_seq_len, 1
        )

        # 4. regional img attends to itself
        # Outer product marks all (i, j) pairs where both pixels are in this region.
        image_mask = image_mask.view(img_seq_len, 1)
        regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T

    # Handle image background regions.
    if background_region_mask is None:
        # There are no region masks, so allow unrestricted img self attention.
        regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
    else:
        # Allow background regions to attend to themselves and to the rest of the image.
        regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
        regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)

    # Convert attention mask to boolean.
    regional_attention_mask = regional_attention_mask > 0.5
    # NOTE(review): the final return is not visible in this view; it is
    # reconstructed to match the sibling _prepare_unrestricted_attn_mask, which
    # ends with the same threshold-then-return pattern — confirm against upstream.
    return regional_attention_mask
|
||||
|
||||
Reference in New Issue
Block a user