From fef632e0e1786b2129c606f85608887a2a1f3234 Mon Sep 17 00:00:00 2001
From: xra
Date: Mon, 29 Aug 2022 12:28:49 +0900
Subject: [PATCH] tokenization logging (take 2)

This adds an optional -t argument that prints out a color-coded view of how
the prompt is tokenized. SD has a maximum of 77 tokens and silently discards
any tokens over that limit if your prompt is too long. With -t you can see
exactly how your prompt is being tokenized, which helps with prompt crafting.
---
 ldm/simplet2i.py | 28 ++++++++++++++++++++++++++++
 scripts/dream.py |  6 ++++++
 2 files changed, 34 insertions(+)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index f1aec32c29..3d24889b53 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -213,6 +213,7 @@ class T2I:
         upscale=None,
         variants=None,
         sampler_name=None,
+        log_tokenization=False,
         **args,
     ):   # eat up additional cruft
         """
@@ -253,6 +254,7 @@ class T2I:
         batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength
+        self.log_tokenization = log_tokenization
 
         model = (
             self.load_model()
@@ -489,6 +491,7 @@ class T2I:
                     weight = weights[i]
                     if not skip_normalize:
                         weight = weight / totalWeight
+                    self._log_tokenization(subprompts[i])
                     c = torch.add(
                         c,
                         self.model.get_learned_conditioning(
@@ -497,6 +500,7 @@ class T2I:
                         alpha=weight,
                     )
         else:   # just standard 1 prompt
+            self._log_tokenization(prompt)
             c = self.model.get_learned_conditioning(batch_size * [prompt])
         return (uc, c)
 
@@ -657,3 +661,27 @@ class T2I:
             weights.append(1.0)
             remaining = 0
         return prompts, weights
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
+    # but for readability it has been replaced with ' '
+    def _log_tokenization(self, text):
+        if not self.log_tokenization:
+            return
+        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
+        tokenized = ""
+        discarded = ""
+        usedTokens = 0
+        totalTokens = len(tokens)
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
+            # alternate color
+            s = (usedTokens % 6) + 1
+            if i < self.model.cond_stage_model.max_length:
+                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+                usedTokens += 1
+            else:  # over max token length
+                discarded = discarded + f"\x1b[0;3{s};40m{token}"
+        print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
+        if discarded != "":
+            print(f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
diff --git a/scripts/dream.py b/scripts/dream.py
index e48b2e64a8..b429c46649 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -462,6 +462,12 @@ def create_cmd_parser():
         metavar='SAMPLER_NAME',
         help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
     )
+    parser.add_argument(
+        '-t',
+        '--log_tokenization',
+        action='store_true',
+        help='shows how the prompt is split into tokens'
+    )
     return parser
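
Note (not part of the patch): below is a minimal standalone sketch of the same idea, for previewing tokenization outside of dream.py. It assumes the Hugging Face transformers package and the openai/clip-vit-large-patch14 tokenizer, neither of which this patch touches; the patch itself uses the model's own cond_stage_model.tokenizer, and it reads the limit from cond_stage_model.max_length rather than hard-coding 77.

# Standalone sketch (assumption: `transformers` is installed and can download
# the openai/clip-vit-large-patch14 tokenizer). It mirrors what _log_tokenization
# prints: which BPE tokens are kept and which fall past the token limit.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

def show_tokenization(text, max_length=77):
    # CLIP's BPE marks end-of-word with '</w>'; swap it for a space for
    # readability, the same substitution _log_tokenization makes.
    tokens = [t.replace('</w>', ' ') for t in tokenizer.tokenize(text)]
    kept, discarded = tokens[:max_length], tokens[max_length:]
    print(f"Tokens ({len(kept)}):\n{''.join(kept)}")
    if discarded:
        print(f"Tokens Discarded ({len(discarded)}):\n{''.join(discarded)}")

show_tokenization('a watercolor painting of a fox in a misty forest')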