Be smarter about selecting the global CLIP embedding for FLUX regional prompting.

Ryan Dick
2024-11-25 20:15:04 +00:00
parent 3741a6f5e0
commit 94c088300f


@@ -64,8 +64,9 @@ class RegionalPromptingExtension:
         )
         # 4. regional img attends to itself
-        image_mask = image_mask.view(img_seq_len, 1)
-        regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
+        # image_mask = image_mask.view(img_seq_len, 1)
+        # regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
+        regional_attention_mask[txt_seq_len:, txt_seq_len:] = True
         return regional_attention_mask
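In effect, this hunk drops the per-region image-to-image restriction: previously each region's image tokens could only attend to image tokens inside the same region (the outer product of the region's mask with itself), whereas now all image tokens attend to each other. Below is a minimal sketch of the two behaviors, with made-up sequence lengths and a toy boolean mask; the real code operates on the extension's precomputed masks.

    import torch

    txt_seq_len, img_seq_len = 4, 6
    seq_len = txt_seq_len + img_seq_len
    regional_attention_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

    # A toy region covering the first three image tokens.
    image_mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool)

    # Old behavior: image tokens attend only within their own region
    # (boolean outer product, equivalent to the removed `image_mask @ image_mask.T`).
    old_img_block = image_mask.view(img_seq_len, 1) & image_mask.view(1, img_seq_len)

    # New behavior: every image token attends to every image token.
    regional_attention_mask[txt_seq_len:, txt_seq_len:] = True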
@@ -79,6 +80,15 @@ class RegionalPromptingExtension:
         concat_image_masks: list[torch.Tensor] = []
         concat_t5_embedding_ranges: list[Range] = []
+        # Choose global CLIP embedding.
+        # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
+        # the first prompt's CLIP embedding.
+        global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
+        for text_conditioning in text_conditionings:
+            if text_conditioning.mask is None:
+                global_clip_embedding = text_conditioning.clip_embeddings
+                break
         cur_t5_embedding_len = 0
         for text_conditioning in text_conditionings:
             concat_t5_embeddings.append(text_conditioning.t5_embeddings)
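For illustration, here is how that selection behaves in isolation. FakeConditioning is a hypothetical stand-in for the real conditioning objects (which also carry T5 embeddings and other fields); the selection logic mirrors the diff: default to the first prompt's CLIP embedding, then prefer the first prompt with no mask, i.e. the first global prompt.

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class FakeConditioning:  # hypothetical stand-in for the real conditioning type
        clip_embeddings: torch.Tensor
        mask: Optional[torch.Tensor]

    regional_prompt = FakeConditioning(torch.ones(1, 768), mask=torch.ones(1, 64, 64))
    global_prompt = FakeConditioning(torch.zeros(1, 768), mask=None)  # mask is None => global

    text_conditionings = [regional_prompt, global_prompt]

    # Same logic as the added hunk: fall back to the first prompt, but use the
    # first unmasked (global) prompt's CLIP embedding if one exists.
    global_clip_embedding = text_conditionings[0].clip_embeddings
    for tc in text_conditionings:
        if tc.mask is None:
            global_clip_embedding = tc.clip_embeddings
            break

    assert torch.equal(global_clip_embedding, global_prompt.clip_embeddings)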
@@ -101,8 +111,7 @@ class RegionalPromptingExtension:
         return FluxRegionalTextConditioning(
             t5_embeddings=t5_embeddings,
-            # HACK(ryand): Be smarter about how we select which CLIP embedding to use.
-            clip_embeddings=text_conditionings[0].clip_embeddings,
+            clip_embeddings=global_clip_embedding,
             t5_txt_ids=t5_txt_ids,
             image_masks=torch.cat(concat_image_masks, dim=1),
             t5_embedding_ranges=concat_t5_embedding_ranges,