BPE tokenizer (#11415)

* BPE works

* refactor tok

* oops

* basic tests

* fix eval

* smaller diff

* fix error

* proper vocab decoding

* use regex for splitting

* escape ucatrange

* full compat

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: leopf
Date: 2025-08-04 18:52:38 +02:00 (committed by GitHub)
Parent: 06af9f9236
Commit: 4f0ee4e982

4 changed files with 114 additions and 31 deletions


@@ -10,7 +10,7 @@ if __name__ == "__main__":
   model, kv = Transformer.from_gguf(Tensor.from_url(models["1B"]), max_context=4096)
-  tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
+  tok = SimpleTokenizer.from_gguf_kv(kv)
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']

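Not part of the diff, but useful context: the new from_gguf_kv classmethod replaces the old positional vocab-list constructor by reading the tokenizer metadata straight out of the GGUF key/value store. A minimal sketch of the keys it consumes; the values here are invented for illustration (real files carry the full vocab):

from tinygrad.apps.llm import SimpleTokenizer

# hypothetical GGUF tokenizer metadata, values invented for illustration
kv = {
  "tokenizer.ggml.pre": "llama-bpe",             # anything but llama3/llama-v3/llama-bpe raises ValueError
  "tokenizer.ggml.tokens": ["a", "b", "<|x|>"],  # position in this list is the token id
  "tokenizer.ggml.token_type": [1, 1, 3],        # type 1 is a normal token, everything else is treated as special
}
tok = SimpleTokenizer.from_gguf_kv(kv)
print(tok.decode(tok.encode("ab")))  # ab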

@@ -1,17 +1,19 @@
 from transformers import AutoTokenizer
 from datasets import load_dataset
-from tinygrad.apps.llm import SimpleTokenizer
-from tinygrad.helpers import tqdm, getenv
+from tinygrad.apps.llm import SimpleTokenizer, gpt2_decode_vocab, get_llama_re
+from tinygrad.helpers import tqdm, getenv, partition
 
 # use ALLOW_FAILED=-1 to go over the entire dataset without printing.
 if __name__ == "__main__":
   base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
-  vocab_words = [ word for word, _ in sorted(base_tokenizer.get_vocab().items(), key=lambda t: t[1]) ]
+  special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()),
+                                            lambda e: e[1] in base_tokenizer.all_special_ids)
   inv_vocab = { tid: word for word, tid in base_tokenizer.get_vocab().items() }
-  simple_tokenizer = SimpleTokenizer(vocab_words)
+  simple_tokenizer = SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
 
   color_codes = [ 91, 92, 94, 93, 95 ]
-  def color_tokens(tids): return "".join(f"\033[{color_codes[i%len(color_codes)]}m{inv_vocab[t]}" for i, t in enumerate(tids)) + "\033[0m"
+  def color_tokens(tids):
+    return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"
 
   ds = load_dataset("OpenAssistant/oasst1")
   allow_failed = getenv("ALLOW_FAILED", 10)

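The script now leans on partition from tinygrad.helpers: as the unpacking above implies, entries satisfying the predicate come back in the first list (the special tokens) and the rest in the second. A self-contained sketch of that behavior, assuming the helper works as these call sites use it:

from typing import Callable, Iterable, TypeVar
T = TypeVar("T")

def partition(itr: Iterable[T], fxn: Callable[[T], bool]) -> tuple[list[T], list[T]]:
  a: list[T] = []  # predicate true, e.g. the special tokens above
  b: list[T] = []  # predicate false, e.g. the normal tokens
  for x in itr: (a if fxn(x) else b).append(x)
  return a, b

assert partition(range(5), lambda x: x % 2 == 0) == ([0, 2, 4], [1, 3])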

@@ -0,0 +1,57 @@
+import unittest, base64, functools
+from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
+from tinygrad.helpers import fetch
+
+class TestLLMTokenizer(unittest.TestCase):
+  @functools.cached_property
+  def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })
+
+  @functools.cached_property
+  def llama_tok(self):
+    # from https://github.com/tinygrad/tinygrad/blob/e0106b6b257ebc003eb3694144e3e198f7d8cc37/examples/llama3.py#L14
+    model_file = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model")
+    with open(model_file, "rt") as fd:
+      str_vocab = [ line.split(maxsplit=1) for line in fd.read().splitlines() if line ]
+    normal_tokens = { base64.b64decode(stok): int(srank) for stok, srank in str_vocab }
+
+    special_tokens = [
+      "<|begin_of_text|>",
+      "<|end_of_text|>",
+      "<|reserved_special_token_0|>",
+      "<|reserved_special_token_1|>",
+      "<|reserved_special_token_2|>",
+      "<|reserved_special_token_3|>",
+      "<|start_header_id|>",
+      "<|end_header_id|>",
+      "<|reserved_special_token_4|>",
+      "<|eot_id|>",
+    ] + [ f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5) ]
+    return SimpleTokenizer(get_llama_re(), normal_tokens, { token: len(normal_tokens) + i for i, token in enumerate(special_tokens) })
+
+  def _test_coding(self, tok: SimpleTokenizer, text: str, expected_tokens: list[int]):
+    self.assertEqual(tok.encode(text), expected_tokens)
+    self.assertEqual(tok.decode(expected_tokens), text)
+
+  def test_abc(self): self._test_coding(self.basic_tok, "abc", [ 3, 2 ])
+  def test_abbc(self): self._test_coding(self.basic_tok, "abbc", [ 3, 4 ])
+  def test_aabbbcc(self): self._test_coding(self.basic_tok, "aabbbcc", [ 0, 3, 1, 4, 2 ])
+  def test_specials1(self): self._test_coding(self.basic_tok, "a<x>a<y>a<z>a", [ 0, 5, 0, 6, 0, 7, 0 ])
+  def test_specials2(self): self._test_coding(self.basic_tok, "<x>a<y>a<z>", [ 5, 0, 6, 0, 7 ])
+  def test_invalid_token(self):
+    with self.assertRaises(RuntimeError): self._test_coding(self.basic_tok, "L", [])
+  def test_no_specials(self): self._test_coding(SimpleTokenizer(".*", { bytes([i]): i for i in range(256) }, {}), "abc", [97, 98, 99])
+
+  # NOTE: the correct tokenization for this can only be found by looking up the text chunk in the vocab, not by applying merges
+  def test_llama_early_tokenize(self): self._test_coding(self.llama_tok, " например", [ 111797 ])
+
+  def test_llama_basic(self): self._test_coding(self.llama_tok, "hello world", [ 15339, 1917 ])
+  def test_llama_control_char(self): self._test_coding(self.llama_tok, " \x850", [ 220, 116360, 15 ])
+  def test_llama_bytes(self): self._test_coding(self.llama_tok, " \xec\x8b\xa4\xed", [ 1717, 105, 116174, 82638, 2483 ])
+  def test_llama_special1(self): self._test_coding(self.llama_tok, "hello <|end_of_text|>", [ 15339, 220, 128001 ])
+  def test_llama_special2(self): self._test_coding(self.llama_tok, "<|start_header_id|>user<|end_header_id|>\n\n", [ 128006, 882, 128007, 271 ])
+  def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
+  def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
+
+if __name__ == '__main__':
+  unittest.main()

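A hand trace of test_aabbbcc shows how the greedy lowest-id merge loop arrives at [0, 3, 1, 4, 2]; this sketch just re-runs the toy tokenizer from the tests above:

from tinygrad.apps.llm import SimpleTokenizer

tok = SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, {})
# "aabbbcc" is not itself a vocab entry, so it starts as single bytes:
#   [a, a, b, b, b, c, c]
# the lowest-id mergeable adjacent pair is "ab" (id 3) -> [a, ab, b, b, c, c]
# the next lowest-id mergeable pair is "bc" (id 4)     -> [a, ab, b, bc, c]
# no adjacent pair is in the vocab anymore, so the loop stops
print(tok.encode("aabbbcc"))  # [0, 3, 1, 4, 2]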

@@ -1,33 +1,57 @@
 from __future__ import annotations
-import sys, argparse
-from tinygrad import Tensor, nn, UOp, TinyJit, getenv
+import sys, argparse, typing, re, itertools, unicodedata
+from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
+
+def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+  c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
+  c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
+  return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }
+
+def get_llama_re():
+  def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
+  r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
+  # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
+  return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
+    f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"
 
 class SimpleTokenizer:
-  def __init__(self, vocab: list[str]):
-    self.vocab: list[str] = vocab
-    self.biggest_token: int = max(map(len, vocab))
-    self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
-    self.replace_space = "Ġ"
-    self.replace_newline = "Ċ"
+  def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
+    self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
+    self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
+    self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")
 
-  def encode(self, text:str) -> list[int]:
-    s = text.replace(" ", self.replace_space).replace("\n", self.replace_newline)
-    out: list[int] = []
-    i = 0
-    while i < len(s):
-      j = min(i+self.biggest_token, len(s))
-      while i < j and (tid:=self.token_to_id.get(s[i:j])) is None: j -= 1
-      if tid is None: raise RuntimeError(f"token not found in {s}")
-      out.append(tid)
-      i = j
-    return out
+  @staticmethod
+  def from_gguf_kv(kv: dict):
+    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
+    if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
+    vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
+    normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
+    return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
 
-  def decode(self, ids: list[int]) -> str:
-    return ''.join(self.vocab[tid] for tid in ids).replace(self.replace_space, " ").replace(self.replace_newline, "\n")
+  def encode(self, text: str):
+    tokens: list[int] = []
+    pos = 0
+    for match in self._special_re.finditer(text):
+      tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
+      pos = match.end(0)
+    return tokens + self._encode_sentence(text[pos:])
 
-  def role(self, role:str):
-    return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)] # llama style
+  def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
+  def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+
+  def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
+  def _encode_word(self, word: bytes):
+    if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
+    parts = [word[i:i+1] for i in range(len(word))]
+    while True:
+      min_tid, min_idx = 2**32, -1
+      for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
+        tid = self._normal_tokens.get(p1 + p2, min_tid)
+        if tid < min_tid: min_tid, min_idx = tid, idx
+      if min_idx == -1: break
+      parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
+    try: return [ self._normal_tokens[p] for p in parts ]
+    except KeyError: raise RuntimeError("token not found")
 
 def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000):
   B, H, T, Hd = x.shape
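One way to see what gpt2_decode_vocab above undoes (and why the old replace_space/replace_newline fields disappear): GPT-2 style vocabs spell raw token bytes as printable characters, with space (0x20) stored as "Ġ" (chr(256+32)) and newline as "Ċ". A quick sketch:

from tinygrad.apps.llm import gpt2_decode_vocab

# keys come in as printable-unicode strings and come out as the real token bytes
print(gpt2_decode_vocab({ "Ġhello": 42, "Ċ": 7 }))  # {b' hello': 42, b'\n': 7}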
@@ -165,7 +189,7 @@ if __name__ == "__main__":
   model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)
 
   # extract some metadata
-  tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
+  tok = SimpleTokenizer.from_gguf_kv(kv)
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
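
Finally, a sanity check on the pretokenizer: get_llama_re expands the unicode-category classes from llama.cpp's pattern into literal codepoint ranges so the stdlib re module can compile it, and this splitting happens before any merging. A sketch (outputs hand-derived from the pattern; test_llama_pat above pins the second case):

import re
from tinygrad.apps.llm import get_llama_re

pat = re.compile(get_llama_re())
print(pat.findall("it's 12345"))  # ['it', "'s", ' ', '123', '45'] -- digit runs are capped at 3
print(pat.findall("today\n \n"))  # ['today', '\n \n'] -- the two chunks behind test_llama_pat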