mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Use a single global CLIP embedding for FLUX regional guidance.
This commit is contained in:
@@ -76,31 +76,20 @@ class RegionalPromptingExtension:
|
||||
) -> FluxRegionalTextConditioning:
|
||||
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
|
||||
concat_t5_embeddings: list[torch.Tensor] = []
|
||||
concat_clip_embeddings: list[torch.Tensor] = []
|
||||
concat_image_masks: list[torch.Tensor] = []
|
||||
concat_t5_embedding_ranges: list[Range] = []
|
||||
concat_clip_embedding_ranges: list[Range] = []
|
||||
|
||||
cur_t5_embedding_len = 0
|
||||
cur_clip_embedding_len = 0
|
||||
for text_conditioning in text_conditionings:
|
||||
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
|
||||
concat_clip_embeddings.append(text_conditioning.clip_embeddings)
|
||||
|
||||
concat_t5_embedding_ranges.append(
|
||||
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
|
||||
)
|
||||
concat_clip_embedding_ranges.append(
|
||||
Range(
|
||||
start=cur_clip_embedding_len,
|
||||
end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
|
||||
)
|
||||
)
|
||||
|
||||
concat_image_masks.append(text_conditioning.mask)
|
||||
|
||||
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
|
||||
cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
|
||||
|
||||
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
|
||||
|
||||
@@ -112,11 +101,11 @@ class RegionalPromptingExtension:
|
||||
|
||||
return FluxRegionalTextConditioning(
|
||||
t5_embeddings=t5_embeddings,
|
||||
clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
|
||||
# HACK(ryand): Be smarter about how we select which CLIP embedding to use.
|
||||
clip_embeddings=text_conditionings[0].clip_embeddings,
|
||||
t5_txt_ids=t5_txt_ids,
|
||||
image_masks=torch.cat(concat_image_masks, dim=1),
|
||||
t5_embedding_ranges=concat_t5_embedding_ranges,
|
||||
clip_embedding_ranges=concat_clip_embedding_ranges,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -15,11 +15,15 @@ class FluxTextConditioning:
|
||||
@dataclass
|
||||
class FluxRegionalTextConditioning:
|
||||
# Concatenated text embeddings.
|
||||
# Shape: (1, concatenated_txt_seq_len, 4096)
|
||||
t5_embeddings: torch.Tensor
|
||||
clip_embeddings: torch.Tensor
|
||||
|
||||
# Shape: (1, concatenated_txt_seq_len, 3)
|
||||
t5_txt_ids: torch.Tensor
|
||||
|
||||
# Global CLIP embeddings.
|
||||
# Shape: (1, 768)
|
||||
clip_embeddings: torch.Tensor
|
||||
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||
# Shape: (1, num_prompts, image_seq_len)
|
||||
# Dtype: torch.bool
|
||||
@@ -27,6 +31,4 @@ class FluxRegionalTextConditioning:
|
||||
|
||||
# List of ranges that represent the embedding ranges for each mask.
|
||||
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
|
||||
# clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i].
|
||||
t5_embedding_ranges: list[Range]
|
||||
clip_embedding_ranges: list[Range]
|
||||
|
||||
Reference in New Issue
Block a user