mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 13:25:22 -05:00
Update flux_text_encoder.py
Added tokenizer logging to flux
This commit is contained in:
committed by
psychedelicious
parent
9066dc1839
commit
bbb5d68146
@@ -1,5 +1,5 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_t5_tokenization(context, t5_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_clip_tokenization(context, clip_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _log_t5_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
|
||||
|
||||
# Tokenize the prompt using the same parameters as the model's text encoder.
|
||||
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=self.t5_max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True, # This is important for T5 to add the EOS token.
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace(" ", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for token in tokens:
|
||||
if token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
used_tokens += 1
|
||||
elif token == tokenizer.pad_token:
|
||||
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
|
||||
continue
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
def _log_clip_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: CLIPTokenizer,
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a CLIP-based model."""
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
attention_mask = tokenized_output.attention_mask[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The CLIP tokenizer uses '</w>' to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("</w>", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for i, token in enumerate(tokens):
|
||||
if attention_mask[i] == 0:
|
||||
# Do not log padding tokens.
|
||||
continue
|
||||
|
||||
if token == tokenizer.bos_token:
|
||||
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
|
||||
elif token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
Reference in New Issue
Block a user