mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Be smarter about selecting the global CLIP embedding for FLUX regional prompting.
This commit is contained in:
@@ -64,8 +64,9 @@ class RegionalPromptingExtension:
|
||||
)
|
||||
|
||||
# 4. img attends to all img tokens (per-region img self-attention masking is currently disabled)
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
|
||||
# image_mask = image_mask.view(img_seq_len, 1)
|
||||
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] = True
|
||||
|
||||
return regional_attention_mask
|
||||
|
||||
@@ -79,6 +80,15 @@ class RegionalPromptingExtension:
|
||||
concat_image_masks: list[torch.Tensor] = []
|
||||
concat_t5_embedding_ranges: list[Range] = []
|
||||
|
||||
# Choose global CLIP embedding.
|
||||
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
|
||||
# the first prompt's CLIP embedding.
|
||||
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
|
||||
for text_conditioning in text_conditionings:
|
||||
if text_conditioning.mask is None:
|
||||
global_clip_embedding = text_conditioning.clip_embeddings
|
||||
break
|
||||
|
||||
cur_t5_embedding_len = 0
|
||||
for text_conditioning in text_conditionings:
|
||||
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
|
||||
@@ -101,8 +111,7 @@ class RegionalPromptingExtension:
|
||||
|
||||
return FluxRegionalTextConditioning(
|
||||
t5_embeddings=t5_embeddings,
|
||||
# HACK(ryand): Be smarter about how we select which CLIP embedding to use.
|
||||
clip_embeddings=text_conditionings[0].clip_embeddings,
|
||||
clip_embeddings=global_clip_embedding,
|
||||
t5_txt_ids=t5_txt_ids,
|
||||
image_masks=torch.cat(concat_image_masks, dim=1),
|
||||
t5_embedding_ranges=concat_t5_embedding_ranges,
|
||||
|
||||
Reference in New Issue
Block a user