Use a single global CLIP embedding for FLUX regional guidance.

This commit is contained in:
Ryan Dick
2024-11-22 23:01:43 +00:00
parent 20356c0746
commit 2c23b8414c
2 changed files with 8 additions and 17 deletions

View File

@@ -76,31 +76,20 @@ class RegionalPromptingExtension:
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
concat_clip_embeddings: list[torch.Tensor] = []
concat_image_masks: list[torch.Tensor] = []
concat_t5_embedding_ranges: list[Range] = []
concat_clip_embedding_ranges: list[Range] = []
cur_t5_embedding_len = 0
cur_clip_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
concat_clip_embeddings.append(text_conditioning.clip_embeddings)
concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)
concat_clip_embedding_ranges.append(
Range(
start=cur_clip_embedding_len,
end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
)
)
concat_image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
@@ -112,11 +101,11 @@ class RegionalPromptingExtension:
return FluxRegionalTextConditioning(
t5_embeddings=t5_embeddings,
clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
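# HACK(ryand): For now, the CLIP embedding of the first text conditioning is used as the single global CLIP
# embedding. Be smarter about how we select which CLIP embedding to use.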
clip_embeddings=text_conditionings[0].clip_embeddings,
t5_txt_ids=t5_txt_ids,
image_masks=torch.cat(concat_image_masks, dim=1),
t5_embedding_ranges=concat_t5_embedding_ranges,
clip_embedding_ranges=concat_clip_embedding_ranges,
)
@staticmethod

View File

@@ -15,11 +15,15 @@ class FluxTextConditioning:
@dataclass
class FluxRegionalTextConditioning:
# Concatenated T5 text embeddings.
# Shape: (1, concatenated_txt_seq_len, 4096)
t5_embeddings: torch.Tensor
clip_embeddings: torch.Tensor
# Shape: (1, concatenated_txt_seq_len, 3)
t5_txt_ids: torch.Tensor
# Global CLIP embeddings.
# Shape: (1, 768)
clip_embeddings: torch.Tensor
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, image_seq_len)
# Dtype: torch.bool
@@ -27,6 +31,4 @@ class FluxRegionalTextConditioning:
# List of ranges that represent the embedding ranges for each mask.
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
# clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i].
t5_embedding_ranges: list[Range]
clip_embedding_ranges: list[Range]