Compare commits

...

2 Commits

Author     SHA1        Message                                                         Date
Ryan Dick  8d04ec3f95  Improve docs related to dynamic T5 sequence length selection.  2024-11-29 16:11:51 +00:00
Ryan Dick  4581a37a48  Dynamically select smaller t5 seq len to save inference time.  2024-11-29 16:02:25 +00:00
2 changed files with 43 additions and 11 deletions
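At a high level, these commits let the T5 encoder run at the shortest sequence length (from a fixed set of multiples of 128) that still fits the tokenized prompt, instead of always padding to the model's maximum. Since the encoder's cost grows with sequence length, short prompts encode faster and use less peak memory. A minimal, self-contained sketch of the selection rule (select_seq_len is an illustrative name, not a function in the diff):

    def select_seq_len(token_count: int, valid_seq_lens: list[int]) -> int:
        """Return the shortest valid sequence length that fits the prompt.

        Falls back to the largest valid length, in which case the tokenizer
        truncates the prompt (matching the behavior in the diff below).
        """
        for candidate in sorted(valid_seq_lens):
            if candidate >= token_count:
                return candidate
        return max(valid_seq_lens)

    # For a FLUX dev model (t5_max_seq_len=512), the valid lengths are [128, 256, 384, 512]:
    assert select_seq_len(20, [128, 256, 384, 512]) == 128   # short prompt -> 4x shorter sequence
    assert select_seq_len(300, [128, 256, 384, 512]) == 384
    assert select_seq_len(900, [128, 256, 384, 512]) == 512  # over the max -> truncated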


@@ -29,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -48,6 +48,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     mask: Optional[TensorField] = InputField(
         default=None, description="A mask defining the region that this conditioning prompt applies to."
@@ -74,6 +79,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            # We allow a minimum sequence length of 128. Going too short results in more significant image changes.
+            valid_seq_lens = list(range(128, self.t5_max_seq_len, 128))
+            valid_seq_lens.append(self.t5_max_seq_len)
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
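For the two supported maxima, the constructed valid_seq_lens lists are small. A worked sketch of the values (mirroring the construction above):

    t5_max_seq_len = 512  # FLUX dev
    valid_seq_lens = list(range(128, t5_max_seq_len, 128))  # [128, 256, 384]
    valid_seq_lens.append(t5_max_seq_len)
    assert valid_seq_lens == [128, 256, 384, 512]

    t5_max_seq_len = 256  # FLUX schnell
    valid_seq_lens = list(range(128, t5_max_seq_len, 128))  # [128]
    valid_seq_lens.append(t5_max_seq_len)
    assert valid_seq_lens == [128, 256]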
@@ -81,10 +92,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
 
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
 
         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -122,10 +133,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
 
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
 
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
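Note that the CLIP call site passes a single-element list, [77], so the new selection logic degenerates to the old fixed-length behavior. A quick standalone check of that invariant:

    # With one valid length, both the "shortest length that fits" and the
    # truncation fallback are 77, so CLIP encoding is unchanged.
    valid_seq_lens = [77]
    for token_count in (5, 77, 200):
        fits = [n for n in sorted(valid_seq_lens) if n >= token_count]
        selected = fits[0] if fits else max(valid_seq_lens)
        assert selected == 77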


@@ -1,32 +1,53 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 
 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        """Encode text into a tensor.
+
+        Args:
+            text: A list of text prompts to encode.
+            valid_seq_lens: A list of valid sequence lengths. The shortest valid sequence length that can contain the
+                text will be used. If the largest valid sequence length cannot contain the text, the encoding will be
+                truncated.
+        """
+        valid_seq_lens = sorted(valid_seq_lens)
+
+        # Perform the initial tokenization with the maximum valid sequence length.
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
+
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        seq_len: int = batch_encoding["length"][0].item()
+        selected_seq_len = valid_seq_lens[-1]
+        for candidate_seq_len in valid_seq_lens:
+            if candidate_seq_len >= seq_len:
+                selected_seq_len = candidate_seq_len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
            output_hidden_states=False,
        )
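The final slice relies on the tokenizer right-padding (T5's default): every real token precedes the padding, so truncating the padded input_ids to selected_seq_len keeps the prompt intact. A standalone torch sketch of that property (the token ids are made up for illustration; 0 is T5's pad id):

    import torch

    pad_id = 0  # T5's pad token id
    # A 4-token prompt padded to the maximum valid length of 512:
    input_ids = torch.tensor([[37, 1712, 632, 1] + [pad_id] * 508])

    selected_seq_len = 128  # the shortest valid length that fits 4 tokens
    sliced = input_ids[..., :selected_seq_len]

    assert sliced.shape == (1, 128)
    assert torch.equal(sliced[0, :4], input_ids[0, :4])  # real tokens preserved
    assert (sliced[0, 4:] == pad_id).all()               # remainder is padding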