diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 9131f06ea3..22b57540eb 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -1,5 +1,5 @@
 from contextlib import ExitStack
-from typing import Iterator, Literal, Optional, Tuple
+from typing import Iterator, Literal, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
             t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
 
+            if context.config.get().log_tokenization:
+                self._log_t5_tokenization(context, t5_tokenizer)
+
             context.util.signal_progress("Running T5 encoder")
             prompt_embeds = t5_encoder(prompt)
 
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
             clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
 
+            if context.config.get().log_tokenization:
+                self._log_clip_tokenization(context, clip_tokenizer)
+
             context.util.signal_progress("Running CLIP encoder")
             pooled_prompt_embeds = clip_encoder(prompt)
 
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
             assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
+
+    def _log_t5_tokenization(
+        self,
+        context: InvocationContext,
+        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+    ) -> None:
+        """Logs the tokenization of a prompt for a T5-based model like FLUX."""
+
+        # Tokenize the prompt using the same parameters as the model's text encoder.
+        # T5 tokenizers add an EOS token (</s>) and then pad to max_length.
+        tokenized_output = tokenizer(
+            self.prompt,
+            padding="max_length",
+            max_length=self.t5_max_seq_len,
+            truncation=True,
+            add_special_tokens=True,  # This is important for T5 to add the EOS token.
+            return_tensors="pt",
+        )
+
+        input_ids = tokenized_output.input_ids[0]
+        tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+        # The T5 tokenizer uses the space-like character '▁' (U+2581) to denote spaces.
+        # We'll replace it with a regular space for readability.
+        tokens = [t.replace("▁", " ") for t in tokens]
+
+        tokenized_str = ""
+        used_tokens = 0
+        for token in tokens:
+            if token == tokenizer.eos_token:
+                tokenized_str += f"\x1b[0;31m{token}\x1b[0m"  # Red for EOS
+                used_tokens += 1
+            elif token == tokenizer.pad_token:
+                # tokenized_str += f"\x1b[0;34m{token}\x1b[0m"  # Blue for PAD
+                continue
+            else:
+                color = (used_tokens % 6) + 1  # Cycle through 6 colors
+                tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
+                used_tokens += 1
+
+        context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
+        context.logger.info(f"{tokenized_str}\x1b[0m")
+
+    def _log_clip_tokenization(
+        self,
+        context: InvocationContext,
+        tokenizer: CLIPTokenizer,
+    ) -> None:
+        """Logs the tokenization of a prompt for a CLIP-based model."""
+        max_length = tokenizer.model_max_length
+
+        tokenized_output = tokenizer(
+            self.prompt,
+            padding="max_length",
+            max_length=max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        input_ids = tokenized_output.input_ids[0]
+        attention_mask = tokenized_output.attention_mask[0]
+        tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+        # The CLIP tokenizer uses '</w>' to mark the end of a word (i.e. a following space).
+        # We'll replace it with a regular space for readability.
+        tokens = [t.replace("</w>", " ") for t in tokens]
+
+        tokenized_str = ""
+        used_tokens = 0
+        for i, token in enumerate(tokens):
+            if attention_mask[i] == 0:
+                # Do not log padding tokens.
+                continue
+
+            if token == tokenizer.bos_token:
+                tokenized_str += f"\x1b[0;32m{token}\x1b[0m"  # Green for BOS
+            elif token == tokenizer.eos_token:
+                tokenized_str += f"\x1b[0;31m{token}\x1b[0m"  # Red for EOS
+            else:
+                color = (used_tokens % 6) + 1  # Cycle through 6 colors
+                tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
+            used_tokens += 1
+
+        context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
+        context.logger.info(f"{tokenized_str}\x1b[0m")
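
For reviewers who want to see the log format without running a full FLUX graph, here is a minimal standalone sketch of the T5 branch's coloring logic. The `t5-small` checkpoint, the 32-token `max_length`, and the prompt are stand-ins for illustration only; the invocation itself uses the tokenizer loaded from the FLUX model config and `self.t5_max_seq_len`.

```python
from transformers import T5TokenizerFast

# Stand-in checkpoint for illustration; FLUX itself loads a T5-XXL variant.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

prompt = "a photo of an astronaut riding a horse"
ids = tokenizer(prompt, padding="max_length", max_length=32, truncation=True).input_ids
# '\u2581' is the SentencePiece '▁' space marker, replaced for readability.
tokens = [t.replace("\u2581", " ") for t in tokenizer.convert_ids_to_tokens(ids)]

colored, used = "", 0
for token in tokens:
    if token == tokenizer.pad_token:
        continue  # padding is skipped, as in the invocation
    # Color 1 (red) for EOS; otherwise cycle through ANSI colors 1-6.
    color = 1 if token == tokenizer.eos_token else (used % 6) + 1
    colored += f"\x1b[0;3{color}m{token}\x1b[0m"
    used += 1

print(f">> [T5 TOKENLOG] Tokens ({used}/32):")
print(colored + "\x1b[0m")
```

Each non-padding token is wrapped in an ANSI SGR sequence (`\x1b[0;3Nm` ... `\x1b[0m`), so adjacent tokens are visually distinguishable in terminals that support ANSI escapes; terminals that don't will show the raw escape bytes.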