mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
clean up the LLM tokenizer (#12653)
* clean up the LLM tokenizer
* simple tokenizer is actually simple
* ugh write good code
73 test/external/external_test_simple_tokenizer.py (vendored)
@@ -1,41 +1,50 @@
import functools, multiprocessing
from transformers import AutoTokenizer
from datasets import load_dataset
from tinygrad.apps.llm import SimpleTokenizer, gpt2_decode_vocab, get_llama_re
from tinygrad.apps.llm import SimpleTokenizer
from tinygrad.helpers import tqdm, getenv, partition

@functools.cache
def get_tokenizers():
print("getting tokenizers")
base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()), lambda e: e[1] in base_tokenizer.all_special_ids)
simple_tokenizer = SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
return base_tokenizer, simple_tokenizer

def test_tokenize(samp) -> bool:
base_tokenizer, simple_tokenizer = get_tokenizers()
idx, txt = samp
try: simple_tokens = tuple(simple_tokenizer.encode(txt))
except RuntimeError: simple_tokens = ()
base_tokens = tuple(base_tokenizer.encode(txt, add_special_tokens=False))
if simple_tokens != base_tokens:
print(f"tokens mismatch at index: {idx}.\n")
color_codes = [91, 92, 94, 93, 95]
def color_tokens(tids):
return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"
print("simple: ", color_tokens(simple_tokens))
print("official:", color_tokens(base_tokens) + "\n")
return False
if simple_tokenizer.decode(simple_tokens) != txt:
print(f"decode mismatch at {idx}")
return False
return True

# use ALLOW_FAILED=-1 to go over the entire dataset without printing.
if __name__ == "__main__":
base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()),
lambda e: e[1] in base_tokenizer.all_special_ids)
inv_vocab = { tid: word for word, tid in base_tokenizer.get_vocab().items() }
simple_tokenizer = SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))

color_codes = [ 91, 92, 94, 93, 95 ]
def color_tokens(tids):
return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"

print("loading datasets")
ds = load_dataset("OpenAssistant/oasst1")
loaded_ds = [(idx, el["text"]) for idx, el in enumerate(ds["train"])]
print(f"loaded {len(loaded_ds)}")

allow_failed = getenv("ALLOW_FAILED", 10)

fail_count, total = 0, 0

for idx, el in enumerate(tqdm(ds["train"])):
total += 1

try: simple_tokens = tuple(simple_tokenizer.encode(el["text"]))
except RuntimeError: simple_tokens = ()
base_tokens = tuple(base_tokenizer.encode(el["text"], add_special_tokens=False))

if simple_tokens != base_tokens:
fail_count += 1
allow_failed -= 1

if allow_failed >= 0:
print(f"tokens mismatch at index: {idx}.\n")

print("simple: ", color_tokens(simple_tokens))
print("official:", color_tokens(base_tokens) + "\n")

if allow_failed == 0: break
print(f"{fail_count}/{total} samples are inconsistent with the official tokenizer.")
with multiprocessing.Pool(16) as pool:
for good in tqdm(pool.imap_unordered(test_tokenize, loaded_ds), total=len(loaded_ds)):
total += 1
if not good:
fail_count += 1
allow_failed -= 1
if allow_failed == 0: break
print(f"{fail_count}/{total} samples are inconsistent with the official tokenizer.")
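
The same round-trip check the harness runs per sample can also be done on a single string. A minimal sketch, not part of the commit, assuming transformers is installed; it reuses the call shapes from the test above.

from transformers import AutoTokenizer
from tinygrad.apps.llm import SimpleTokenizer
from tinygrad.helpers import partition

# build both tokenizers the same way get_tokenizers() above does
base = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
special, normal = partition(((t, tid) for t, tid in base.vocab.items()), lambda e: e[1] in base.all_special_ids)
simple = SimpleTokenizer(dict(normal), dict(special))

# round-trip a single sample instead of the whole OpenAssistant/oasst1 dataset
text = "The quick brown fox"
assert tuple(simple.encode(text)) == tuple(base.encode(text, add_special_tokens=False))
assert simple.decode(simple.encode(text)) == text
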
@@ -1,19 +1,21 @@
import unittest, base64, functools, sys
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
from tinygrad.apps.llm import SimpleTokenizer
from tinygrad.helpers import fetch

@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
class TestLLMTokenizer(unittest.TestCase):
@functools.cached_property
def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })

@functools.cached_property
def llama_tok(self):
# from https://github.com/tinygrad/tinygrad/blob/e0106b6b257ebc003eb3694144e3e198f7d8cc37/examples/llama3.py#L14
model_file = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model")
with open(model_file, "rt") as fd:
str_vocab = [ line.split(maxsplit=1) for line in fd.read().splitlines() if line ]
normal_tokens = { base64.b64decode(stok): int(srank) for stok, srank in str_vocab }
str_vocab = [line.split(maxsplit=1) for line in fd.read().splitlines() if line]

# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
_byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
_byte_encoder = {v:k for k,v in _byte_decoder.items()}
normal_tokens = {''.join([_byte_encoder[x] for x in base64.b64decode(stok)]): int(srank) for stok, srank in str_vocab}

special_tokens = [
"<|begin_of_text|>",
@@ -27,22 +29,12 @@ class TestLLMTokenizer(unittest.TestCase):
"<|reserved_special_token_4|>",
"<|eot_id|>",
] + [ f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5) ]
return SimpleTokenizer(get_llama_re(), normal_tokens, { token: len(normal_tokens) + i for i, token in enumerate(special_tokens) })
return SimpleTokenizer(normal_tokens, {token: len(normal_tokens) + i for i, token in enumerate(special_tokens)})

def _test_coding(self, tok: SimpleTokenizer, text: str, expected_tokens: list[int]):
self.assertEqual(tok.encode(text), expected_tokens)
self.assertEqual(tok.decode(expected_tokens), text)

def test_abc(self): self._test_coding(self.basic_tok, "abc", [ 3, 2 ])
def test_abbc(self): self._test_coding(self.basic_tok, "abbc", [ 3, 4 ])
def test_aabbbcc(self): self._test_coding(self.basic_tok, "aabbbcc", [ 0, 3, 1, 4, 2 ])
def test_specials1(self): self._test_coding(self.basic_tok, "a<x>a<y>a<z>a", [ 0, 5, 0, 6, 0, 7, 0 ])
def test_specials2(self): self._test_coding(self.basic_tok, "<x>a<y>a<z>", [ 5, 0, 6, 0, 7 ])
def test_invalid_token(self):
with self.assertRaises(RuntimeError): self._test_coding(self.basic_tok, "L", [])

def test_no_specials(self): self._test_coding(SimpleTokenizer(".*", { bytes([i]): i for i in range(256) }, {}), "abc", [97, 98, 99])

# NOTE: the correct tokenization for this can only be found by looking up the text chunk in the vocab, not by applying merges
def test_llama_early_tokenize(self): self._test_coding(self.llama_tok, " например", [ 111797 ])
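
The NOTE above is worth a toy illustration (an example of mine, not from the commit): a chunk can be present in the vocab even though no sequence of pairwise merges can assemble it, which is why _encode_word in tinygrad/apps/llm.py below tries a whole-chunk lookup before merging. Assuming the new two-argument constructor:

from tinygrad.apps.llm import SimpleTokenizer

# "abc" is a token, but "b", "ab" and "bc" are not, so byte-by-byte merging can never
# reach it; only the whole-chunk vocab lookup in _encode_word finds it.
tok = SimpleTokenizer({"a": 0, "c": 1, "abc": 2}, {})
assert tok.encode("abc") == [2]
assert tok.decode([2]) == "abc"
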
@@ -1,58 +1,55 @@
from __future__ import annotations
import sys, argparse, typing, re, itertools, unicodedata
import sys, argparse, typing, re, unicodedata
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers

def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }

def get_llama_re():
def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"
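
To see what this pretokenizer pattern does to text, here is a small illustration (a sketch of mine, not from the commit) that pokes at the private _split_to_word attribute, which SimpleTokenizer.__init__ below builds from the same llama.cpp-derived regex; the expected chunks are indicative only.

from tinygrad.apps.llm import SimpleTokenizer

# a tiny vocab is enough; we only want the compiled pretokenizer pattern (private, may change)
tok = SimpleTokenizer({"a": 0}, {})
print(tok._split_to_word.findall("Hello, world don't"))
# expect chunks along the lines of: ['Hello', ',', ' world', ' don', "'t"]
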
class SimpleTokenizer:
def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}

# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
self._split_to_word = re.compile("(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+")
self._split_to_sentence = re.compile("|".join(re.escape(tok) for tok in special_tokens.keys()) if special_tokens else r"(?!)")

self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
self._special_tokens = special_tokens
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}

@staticmethod
def from_gguf_kv(kv: dict):
def from_gguf_kv(kv:dict):
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))

def encode(self, text: str):
def _encode_word(self, word:bytes) -> list[int]:
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
parts = [bytes([b]) for b in word]
# greedily merge any parts that we can
while True:
i = min([(sys.maxsize, -1)] + [(self._normal_tokens.get(parts[j]+parts[j+1], sys.maxsize), j) for j in range(len(parts)-1)])[1]
if i == -1: break
parts[i:i+2] = [parts[i] + parts[i+1]]
try: return [self._normal_tokens[p] for p in parts]
except KeyError: raise RuntimeError("token not found")
def _encode_sentence(self, chunk:str) -> list[int]:
return [tok for word in self._split_to_word.findall(chunk) for tok in self._encode_word(word.encode())]
def encode(self, text:str) -> list[int]:
tokens: list[int] = []
pos = 0
for match in self._special_re.finditer(text):
for match in self._split_to_sentence.finditer(text):
tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
pos = match.end(0)
return tokens + self._encode_sentence(text[pos:])

def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")

def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
def _encode_word(self, word: bytes):
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
parts = [word[i:i+1] for i in range(len(word))]
while True:
min_tid, min_idx = 2**32, -1
for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
tid = self._normal_tokens.get(p1 + p2, min_tid)
if tid < min_tid: min_tid, min_idx = tid, idx
if min_idx == -1: break
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
try: return [ self._normal_tokens[p] for p in parts ]
except KeyError: raise RuntimeError("token not found")
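
The greedy merge loop in _encode_word above always collapses the adjacent pair whose merged token has the lowest id, one merge per iteration. Here is a sketch of the trace on the toy vocab used by basic_tok in the tests, assuming the new two-argument constructor (the trace comments are mine, not from the commit):

from tinygrad.apps.llm import SimpleTokenizer

tok = SimpleTokenizer({"a": 0, "b": 1, "c": 2, "ab": 3, "bc": 4}, {"<x>": 5, "<y>": 6, "<z>": 7})

# "aabbbcc" starts as [a, a, b, b, b, c, c] and is merged greedily by lowest token id:
#   merge "ab" (id 3): [a, ab, b, b, c, c]
#   merge "bc" (id 4): [a, ab, b, bc, c]
#   no remaining adjacent pair is in the vocab, so the loop stops
assert tok.encode("aabbbcc") == [0, 3, 1, 4, 2]  # same expectation as test_aabbbcc above
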
def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
B, H, T, Hd = x.shape
assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
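
Finally, the byte table built in SimpleTokenizer.__init__ (and in the test's llama_tok) is the standard GPT-2 trick: every byte value is assigned a printable character so a byte-level BPE vocab can be stored as plain strings. A self-contained sketch of the round trip, mirroring the construction above:

# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i, b in enumerate(b for b in range(256) if b not in bs)}
byte_encoder = {v: k for k, v in byte_decoder.items()}

raw = "hello world".encode()
as_text = "".join(byte_encoder[b] for b in raw)        # the space byte 0x20 shows up as 'Ġ'
assert bytes(byte_decoder[c] for c in as_text) == raw  # lossless round trip
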