From b9eb5b5d49e7b36219f4fb1646cd33ffa4161eae Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 14 Oct 2025 14:22:01 +0800
Subject: [PATCH] clean up the LLM tokenizer (#12653)

* clean up the LLM tokenizer

* simple tokenizer is actually simple

* ugh write good code
---
 .../external_test_simple_tokenizer.py         | 73 +++++++++++--------
 test/unit/test_llm_tokenizer.py               | 26 +++----
 tinygrad/apps/llm.py                          | 69 +++++++++---------
 3 files changed, 83 insertions(+), 85 deletions(-)

diff --git a/test/external/external_test_simple_tokenizer.py b/test/external/external_test_simple_tokenizer.py
index 9c3ca8f420..8fc3299ee1 100644
--- a/test/external/external_test_simple_tokenizer.py
+++ b/test/external/external_test_simple_tokenizer.py
@@ -1,41 +1,50 @@
+import functools, multiprocessing
 from transformers import AutoTokenizer
 from datasets import load_dataset
-from tinygrad.apps.llm import SimpleTokenizer, gpt2_decode_vocab, get_llama_re
+from tinygrad.apps.llm import SimpleTokenizer
 from tinygrad.helpers import tqdm, getenv, partition
 
+@functools.cache
+def get_tokenizers():
+  print("getting tokenizers")
+  base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
+  special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()), lambda e: e[1] in base_tokenizer.all_special_ids)
+  simple_tokenizer = SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
+  return base_tokenizer, simple_tokenizer
+
+def test_tokenize(samp) -> bool:
+  base_tokenizer, simple_tokenizer = get_tokenizers()
+  idx, txt = samp
+  try: simple_tokens = tuple(simple_tokenizer.encode(txt))
+  except RuntimeError: simple_tokens = ()
+  base_tokens = tuple(base_tokenizer.encode(txt, add_special_tokens=False))
+  if simple_tokens != base_tokens:
+    print(f"tokens mismatch at index: {idx}.\n")
+    color_codes = [91, 92, 94, 93, 95]
+    def color_tokens(tids):
+      return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"
+    print("simple:  ", color_tokens(simple_tokens))
+    print("official:", color_tokens(base_tokens) + "\n")
+    return False
+  if simple_tokenizer.decode(simple_tokens) != txt:
+    print(f"decode mismatch at {idx}")
+    return False
+  return True
+
 # use ALLOW_FAILED=-1 to go over the entire dataset without printing.
 if __name__ == "__main__":
-  base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
-  special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()),
-                                            lambda e: e[1] in base_tokenizer.all_special_ids)
-  inv_vocab = { tid: word for word, tid in base_tokenizer.get_vocab().items() }
-  simple_tokenizer = SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
-
-  color_codes = [ 91, 92, 94, 93, 95 ]
-  def color_tokens(tids):
-    return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"
-
+  print("loading datasets")
   ds = load_dataset("OpenAssistant/oasst1")
+  loaded_ds = [(idx, el["text"]) for idx, el in enumerate(ds["train"])]
+  print(f"loaded {len(loaded_ds)}")
 
   allow_failed = getenv("ALLOW_FAILED", 10)
-
   fail_count, total = 0, 0
-
-  for idx, el in enumerate(tqdm(ds["train"])):
-    total += 1
-
-    try: simple_tokens = tuple(simple_tokenizer.encode(el["text"]))
-    except RuntimeError: simple_tokens = ()
-    base_tokens = tuple(base_tokenizer.encode(el["text"], add_special_tokens=False))
-
-    if simple_tokens != base_tokens:
-      fail_count += 1
-      allow_failed -= 1
-
-      if allow_failed >= 0:
-        print(f"tokens mismatch at index: {idx}.\n")
-
-        print("simple:  ", color_tokens(simple_tokens))
-        print("official:", color_tokens(base_tokens) + "\n")
-
-      if allow_failed == 0: break
-
-  print(f"{fail_count}/{total} samples are inconsistent with the official tokenizer.")
+  with multiprocessing.Pool(16) as pool:
+    for good in tqdm(pool.imap_unordered(test_tokenize, loaded_ds), total=len(loaded_ds)):
+      total += 1
+      if not good:
+        fail_count += 1
+        allow_failed -= 1
+        if allow_failed == 0: break
+  print(f"{fail_count}/{total} samples are inconsistent with the official tokenizer.")
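
A note on the harness above: `get_tokenizers` is decorated with `functools.cache`, so each `multiprocessing.Pool` worker process loads the HuggingFace tokenizer once and then reuses it for every sample it checks (the cache lives inside each worker; it is not shared across processes). A minimal sketch of the same pattern, with a hypothetical `expensive_init`/`work` pair standing in for `get_tokenizers`/`test_tokenize`:

  import functools, multiprocessing, os

  @functools.cache
  def expensive_init():
    # runs at most once per worker process; functools.cache is per-process, not shared
    print(f"init in pid {os.getpid()}")
    return {"model": "loaded"}

  def work(x:int) -> int:
    expensive_init()  # cache hit after the first call in this worker
    return x * 2

  if __name__ == "__main__":
    with multiprocessing.Pool(4) as pool:
      # imap_unordered yields results as workers finish, which keeps progress bars responsive
      print(sum(pool.imap_unordered(work, range(100))))
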
diff --git a/test/unit/test_llm_tokenizer.py b/test/unit/test_llm_tokenizer.py
index 7b65818a6f..1e7f6cb48a 100644
--- a/test/unit/test_llm_tokenizer.py
+++ b/test/unit/test_llm_tokenizer.py
@@ -1,19 +1,21 @@
 import unittest, base64, functools, sys
-from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
+from tinygrad.apps.llm import SimpleTokenizer
 from tinygrad.helpers import fetch
 
 @unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
 class TestLLMTokenizer(unittest.TestCase):
-  @functools.cached_property
-  def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<a>": 5, "<b>": 6, "<c>": 7 })
-
   @functools.cached_property
   def llama_tok(self):
     # from https://github.com/tinygrad/tinygrad/blob/e0106b6b257ebc003eb3694144e3e198f7d8cc37/examples/llama3.py#L14
     model_file = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model")
     with open(model_file, "rt") as fd:
-      str_vocab = [ line.split(maxsplit=1) for line in fd.read().splitlines() if line ]
-    normal_tokens = { base64.b64decode(stok): int(srank) for stok, srank in str_vocab }
+      str_vocab = [line.split(maxsplit=1) for line in fd.read().splitlines() if line]
+
+    # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+    bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
+    _byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
+    _byte_encoder = {v:k for k,v in _byte_decoder.items()}
+    normal_tokens = {''.join([_byte_encoder[x] for x in base64.b64decode(stok)]): int(srank) for stok, srank in str_vocab}
 
     special_tokens = [
       "<|begin_of_text|>",
@@ -27,22 +29,12 @@ class TestLLMTokenizer(unittest.TestCase):
       "<|reserved_special_token_4|>",
       "<|eot_id|>",
     ] + [ f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5) ]
-    return SimpleTokenizer(get_llama_re(), normal_tokens, { token: len(normal_tokens) + i for i, token in enumerate(special_tokens) })
+    return SimpleTokenizer(normal_tokens, {token: len(normal_tokens) + i for i, token in enumerate(special_tokens)})
 
   def _test_coding(self, tok: SimpleTokenizer, text: str, expected_tokens: list[int]):
     self.assertEqual(tok.encode(text), expected_tokens)
     self.assertEqual(tok.decode(expected_tokens), text)
 
-  def test_abc(self): self._test_coding(self.basic_tok, "abc", [ 3, 2 ])
-  def test_abbc(self): self._test_coding(self.basic_tok, "abbc", [ 3, 4 ])
-  def test_aabbbcc(self): self._test_coding(self.basic_tok, "aabbbcc", [ 0, 3, 1, 4, 2 ])
-  def test_specials1(self): self._test_coding(self.basic_tok, "a<a>a<b>a<c>a", [ 0, 5, 0, 6, 0, 7, 0 ])
-  def test_specials2(self): self._test_coding(self.basic_tok, "<a>a<b>a<c>", [ 5, 0, 6, 0, 7 ])
-  def test_invalid_token(self):
-    with self.assertRaises(RuntimeError): self._test_coding(self.basic_tok, "L", [])
-
-  def test_no_specials(self): self._test_coding(SimpleTokenizer(".*", { bytes([i]): i for i in range(256) }, {}), "abc", [97, 98, 99])
-
   # NOTE: the correct tokenization for this can only be found by looking up the text chunk in the vocab, not by applying merges
   def test_llama_early_tokenize(self): self._test_coding(self.llama_tok, " например", [ 111797 ])
 
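
A note on the byte mapping inlined above (and reused in `SimpleTokenizer.__init__` below): GPT-2 style BPE vocabs store raw bytes as printable unicode characters, so the 188 byte values in `bs` stand for themselves and the remaining 68 (controls, space, 0x7f-0xa0, and the soft hyphen 0xad) are shifted up to code points 256 and beyond. A standalone round-trip check mirroring the patch's comprehensions:

  bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
  byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
  byte_encoder = {v:k for k,v in byte_decoder.items()}

  assert len(byte_decoder) == 256                                     # every byte gets exactly one printable stand-in
  assert all(byte_decoder[byte_encoder[b]] == b for b in range(256))  # the mapping round-trips losslessly
  print(byte_encoder[32], byte_encoder[10])                           # space -> 'Ġ', newline -> 'Ċ'
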
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 77fa753ec0..df0d6d6db7 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,58 +1,55 @@
 from __future__ import annotations
-import sys, argparse, typing, re, itertools, unicodedata
+import sys, argparse, typing, re, unicodedata
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
 
-def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
-  c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
-  c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
-  return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }
-
-def get_llama_re():
-  def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
-  r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
-  # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
-  return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
-    f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"
-
 class SimpleTokenizer:
-  def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
-    self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
-    self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
-    self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")
+  def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
+    # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+    bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
+    self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
+
+    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
+    def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
+    r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
+    self._split_to_word = re.compile("(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
+      f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+")
+    self._split_to_sentence = re.compile("|".join(re.escape(tok) for tok in special_tokens.keys()) if special_tokens else r"(?!)")
+
+    self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
+    self._special_tokens = special_tokens
+    self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
 
   @staticmethod
-  def from_gguf_kv(kv: dict):
+  def from_gguf_kv(kv:dict):
     # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
     if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
     vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
     normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
-    return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
+    return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
 
-  def encode(self, text: str):
+  def _encode_word(self, word:bytes) -> list[int]:
+    if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
+    parts = [bytes([b]) for b in word]
+    # greedily merge any parts that we can
+    while True:
+      i = min([(sys.maxsize, -1)] + [(self._normal_tokens.get(parts[j]+parts[j+1], sys.maxsize), j) for j in range(len(parts)-1)])[1]
+      if i == -1: break
+      parts[i:i+2] = [parts[i] + parts[i+1]]
+    try: return [self._normal_tokens[p] for p in parts]
+    except KeyError: raise RuntimeError("token not found")
+  def _encode_sentence(self, chunk:str) -> list[int]:
+    return [tok for word in self._split_to_word.findall(chunk) for tok in self._encode_word(word.encode())]
+  def encode(self, text:str) -> list[int]:
     tokens: list[int] = []
     pos = 0
-    for match in self._special_re.finditer(text):
+    for match in self._split_to_sentence.finditer(text):
       tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
       pos = match.end(0)
     return tokens + self._encode_sentence(text[pos:])
 
-  def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
+  def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
 
   def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
 
-  def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
-
-  def _encode_word(self, word: bytes):
-    if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
-    parts = [word[i:i+1] for i in range(len(word))]
-    while True:
-      min_tid, min_idx = 2**32, -1
-      for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
-        tid = self._normal_tokens.get(p1 + p2, min_tid)
-        if tid < min_tid: min_tid, min_idx = tid, idx
-      if min_idx == -1: break
-      parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
-
-    try: return [ self._normal_tokens[p] for p in parts ]
-    except KeyError: raise RuntimeError("token not found")
-
 def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
   B, H, T, Hd = x.shape
   assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
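
The rewritten `_encode_word` above keeps the classic BPE merge loop but folds the old rank scan into a single `min` over `(rank, index)` tuples: the `(sys.maxsize, -1)` sentinel wins only when no adjacent pair is in the vocab, and `parts[i:i+2] = [parts[i] + parts[i+1]]` splices in the lowest-ranked merge. A standalone trace of that loop, reusing the toy vocab from the deleted `basic_tok` tests:

  import sys

  vocab = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4}

  def encode_word(word:bytes) -> list[int]:
    if (tid := vocab.get(word)) is not None: return [tid]
    parts = [bytes([b]) for b in word]
    while True:
      # lowest merge rank wins; ties break toward the leftmost pair
      i = min([(sys.maxsize, -1)] + [(vocab.get(parts[j]+parts[j+1], sys.maxsize), j) for j in range(len(parts)-1)])[1]
      if i == -1: break
      parts[i:i+2] = [parts[i] + parts[i+1]]
    return [vocab[p] for p in parts]

  print(encode_word(b"abbc"))     # [3, 4]: "ab" (rank 3) merges before "bc" (rank 4)
  print(encode_word(b"aabbbcc"))  # [0, 3, 1, 4, 2], matching the deleted test_aabbbcc
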