tinygrad/test/external/external_test_simple_tokenizer.py
George Hotz b9eb5b5d49 clean up the LLM tokenizer (#12653), 2025-10-14 14:22:01 +08:00

* clean up the LLM tokenizer
* simple tokenizer is actually simple
* ugh write good code


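# external test: compare tinygrad's SimpleTokenizer against the official Llama 3 tokenizer
# from HuggingFace on every sample of the OpenAssistant/oasst1 dataset (encode and decode round-trip).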
import functools, multiprocessing
from transformers import AutoTokenizer
from datasets import load_dataset
from tinygrad.apps.llm import SimpleTokenizer
from tinygrad.helpers import tqdm, getenv, partition
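
# build both tokenizers once per process: the HF vocab is split into special and normal tokens for SimpleTokenizer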
@functools.cache
def get_tokenizers():
  print("getting tokenizers")
  base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
  special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()), lambda e: e[1] in base_tokenizer.all_special_ids)
  simple_tokenizer = SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
  return base_tokenizer, simple_tokenizer
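
# True iff SimpleTokenizer produces the same token ids as the official tokenizer and decode round-trips the text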
def test_tokenize(samp) -> bool:
  base_tokenizer, simple_tokenizer = get_tokenizers()
  idx, txt = samp

  try: simple_tokens = tuple(simple_tokenizer.encode(txt))
  except RuntimeError: simple_tokens = ()
  base_tokens = tuple(base_tokenizer.encode(txt, add_special_tokens=False))

  if simple_tokens != base_tokens:
    print(f"tokens mismatch at index: {idx}.\n")
    # decode each token separately and cycle through ANSI colors so token boundaries are visible
    color_codes = [91, 92, 94, 93, 95]
    def color_tokens(tids):
      return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"
    print("simple:  ", color_tokens(simple_tokens))
    print("official:", color_tokens(base_tokens) + "\n")
    return False

  if simple_tokenizer.decode(simple_tokens) != txt:
    print(f"decode mismatch at {idx}")
    return False
  return True
# use ALLOW_FAILED=-1 to go over the entire dataset without stopping at failures.
if __name__ == "__main__":
  print("loading datasets")
  ds = load_dataset("OpenAssistant/oasst1")
  loaded_ds = [(idx, el["text"]) for idx, el in enumerate(ds["train"])]
  print(f"loaded {len(loaded_ds)}")

  allow_failed = getenv("ALLOW_FAILED", 10)
  fail_count, total = 0, 0
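  # fan samples out to 16 worker processes; functools.cache rebuilds the tokenizers only once per worker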
  with multiprocessing.Pool(16) as pool:
    for good in tqdm(pool.imap_unordered(test_tokenize, loaded_ds), total=len(loaded_ds)):
      total += 1
      if not good:
        fail_count += 1
        allow_failed -= 1
        if allow_failed == 0: break

  print(f"{fail_count}/{total} samples are inconsistent with the official tokenizer.")