BPE tokenizer (#11415)
* BPE works
* refactor tok
* oops
* basic tests
* fix eval
* smaller diff
* fix error
* proper vocab decoding
* use regex for splitting
* escape ucat_range
* full compat

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
test/external/external_llm_eval.py (2 changes, vendored)
@@ -10,7 +10,7 @@ if __name__ == "__main__":
   model, kv = Transformer.from_gguf(Tensor.from_url(models["1B"]), max_context=4096)
-  tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
+  tok = SimpleTokenizer.from_gguf_kv(kv)
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
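The new entry point can be tried without a real model file. A toy sketch (this kv dict is invented for illustration; real GGUF metadata carries the full vocabulary):

  from tinygrad.apps.llm import SimpleTokenizer

  # minimal stand-in for GGUF kv metadata: token_type 1 marks normal tokens,
  # anything else is treated as special (see from_gguf_kv in the diff below)
  kv = {
    "tokenizer.ggml.pre": "llama-bpe",  # must be a llama3-style preset
    "tokenizer.ggml.tokens": ["a", "b", "ab", "<|end_of_text|>"],
    "tokenizer.ggml.token_type": [1, 1, 1, 3],
  }
  tok = SimpleTokenizer.from_gguf_kv(kv)  # builds the split regex, takes a moment
  print(tok.encode("ab<|end_of_text|>"))  # -> [2, 3]
  print(tok.decode([2, 3]))               # -> "ab<|end_of_text|>"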
test/external/external_test_simple_tokenizer.py (12 changes, vendored)
@@ -1,17 +1,19 @@
 from transformers import AutoTokenizer
 from datasets import load_dataset
-from tinygrad.apps.llm import SimpleTokenizer
-from tinygrad.helpers import tqdm, getenv
+from tinygrad.apps.llm import SimpleTokenizer, gpt2_decode_vocab, get_llama_re
+from tinygrad.helpers import tqdm, getenv, partition

 # use ALLOW_FAILED=-1 to go over the entire dataset without printing.
 if __name__ == "__main__":
   base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
-  vocab_words = [ word for word, _ in sorted(base_tokenizer.get_vocab().items(), key=lambda t: t[1]) ]
+  special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()),
+                                            lambda e: e[1] in base_tokenizer.all_special_ids)
   inv_vocab = { tid: word for word, tid in base_tokenizer.get_vocab().items() }
-  simple_tokenizer = SimpleTokenizer(vocab_words)
+  simple_tokenizer = SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))

   color_codes = [ 91, 92, 94, 93, 95 ]
-  def color_tokens(tids): return "".join(f"\033[{color_codes[i%len(color_codes)]}m{inv_vocab[t]}" for i, t in enumerate(tids)) + "\033[0m"
+  def color_tokens(tids):
+    return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m"

   ds = load_dataset("OpenAssistant/oasst1")
   allow_failed = getenv("ALLOW_FAILED", 10)
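This relies on tinygrad's partition helper. A behaviorally equivalent standalone sketch, assuming (consistent with both call sites in this diff) that the first returned list holds the items where the predicate is true:

  from typing import Callable, Iterable, TypeVar
  T = TypeVar("T")

  def partition(itr: Iterable[T], fxn: Callable[[T], bool]) -> tuple[list[T], list[T]]:
    a: list[T] = []  # items where fxn(item) is true
    b: list[T] = []  # items where fxn(item) is false
    for item in itr: (a if fxn(item) else b).append(item)
    return a, b

  evens, odds = partition(range(6), lambda n: n % 2 == 0)
  print(evens, odds)  # [0, 2, 4] [1, 3, 5]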
test/unit/test_llm_tokenizer.py (57 lines, new file)
@@ -0,0 +1,57 @@
+import unittest, base64, functools
+from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
+from tinygrad.helpers import fetch
+
+class TestLLMTokenizer(unittest.TestCase):
+  @functools.cached_property
+  def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5, "<y>": 6, "<z>": 7 })
+
+  @functools.cached_property
+  def llama_tok(self):
+    # from https://github.com/tinygrad/tinygrad/blob/e0106b6b257ebc003eb3694144e3e198f7d8cc37/examples/llama3.py#L14
+    model_file = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model")
+    with open(model_file, "rt") as fd:
+      str_vocab = [ line.split(maxsplit=1) for line in fd.read().splitlines() if line ]
+    normal_tokens = { base64.b64decode(stok): int(srank) for stok, srank in str_vocab }
+
+    special_tokens = [
+      "<|begin_of_text|>",
+      "<|end_of_text|>",
+      "<|reserved_special_token_0|>",
+      "<|reserved_special_token_1|>",
+      "<|reserved_special_token_2|>",
+      "<|reserved_special_token_3|>",
+      "<|start_header_id|>",
+      "<|end_header_id|>",
+      "<|reserved_special_token_4|>",
+      "<|eot_id|>",
+    ] + [ f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5) ]
+    return SimpleTokenizer(get_llama_re(), normal_tokens, { token: len(normal_tokens) + i for i, token in enumerate(special_tokens) })
+
+  def _test_coding(self, tok: SimpleTokenizer, text: str, expected_tokens: list[int]):
+    self.assertEqual(tok.encode(text), expected_tokens)
+    self.assertEqual(tok.decode(expected_tokens), text)
+
+  def test_abc(self): self._test_coding(self.basic_tok, "abc", [ 3, 2 ])
+  def test_abbc(self): self._test_coding(self.basic_tok, "abbc", [ 3, 4 ])
+  def test_aabbbcc(self): self._test_coding(self.basic_tok, "aabbbcc", [ 0, 3, 1, 4, 2 ])
+  def test_specials1(self): self._test_coding(self.basic_tok, "a<x>a<y>a<z>a", [ 0, 5, 0, 6, 0, 7, 0 ])
+  def test_specials2(self): self._test_coding(self.basic_tok, "<x>a<y>a<z>", [ 5, 0, 6, 0, 7 ])
+  def test_invalid_token(self):
+    with self.assertRaises(RuntimeError): self._test_coding(self.basic_tok, "L", [])
+
+  def test_no_specials(self): self._test_coding(SimpleTokenizer(".*", { bytes([i]): i for i in range(256) }, {}), "abc", [97, 98, 99])
+
+  # NOTE: the correct tokenization for this can only be found by looking up the text chunk in the vocab, not by applying merges
+  def test_llama_early_tokenize(self): self._test_coding(self.llama_tok, " например", [ 111797 ])
+
+  def test_llama_basic(self): self._test_coding(self.llama_tok, "hello world", [ 15339, 1917 ])
+  def test_llama_control_char(self): self._test_coding(self.llama_tok, " \x850", [ 220, 116360, 15 ])
+  def test_llama_bytes(self): self._test_coding(self.llama_tok, " \xec\x8b\xa4\xed", [ 1717, 105, 116174, 82638, 2483 ])
+  def test_llama_special1(self): self._test_coding(self.llama_tok, "hello <|end_of_text|>", [ 15339, 220, 128001 ])
+  def test_llama_special2(self): self._test_coding(self.llama_tok, "<|start_header_id|>user<|end_header_id|>\n\n", [ 128006, 882, 128007, 271 ])
+  def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
+  def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
+
+if __name__ == '__main__':
+  unittest.main()
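The expectations above follow lowest-rank-first merging: with vocab { a:0, b:1, c:2, ab:3, bc:4 }, "abbc" merges "ab" (rank 3) before "bc" (rank 4), giving [3, 4]. A standalone sketch of that rank-greedy loop (mirroring _encode_word from the diff below; the bpe name here is illustrative):

  vocab = { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }

  def bpe(word: bytes) -> list[int]:
    if word in vocab: return [vocab[word]]  # whole-word hit, skip merging
    parts = [word[i:i+1] for i in range(len(word))]
    while True:
      # collect every adjacent pair whose merge exists, ranked by token id
      ranked = [(vocab[p1 + p2], i) for i, (p1, p2) in enumerate(zip(parts, parts[1:])) if p1 + p2 in vocab]
      if not ranked: break
      _, i = min(ranked)                    # lowest rank wins, ties go left
      parts[i:i+2] = [parts[i] + parts[i+1]]
    return [vocab[p] for p in parts]

  print(bpe(b"abbc"))     # [3, 4]
  print(bpe(b"aabbbcc"))  # [0, 3, 1, 4, 2], matching test_aabbbcc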
tinygrad/apps/llm.py
@@ -1,33 +1,57 @@
 from __future__ import annotations
-import sys, argparse
-from tinygrad import Tensor, nn, UOp, TinyJit, getenv
+import sys, argparse, typing, re, itertools, unicodedata
+from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers

+def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
+  c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) }
+  c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) })
+  return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() }
+
+def get_llama_re():
+  def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
+  r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
+  # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
+  return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
+    f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+"

 class SimpleTokenizer:
-  def __init__(self, vocab: list[str]):
-    self.vocab: list[str] = vocab
-    self.biggest_token: int = max(map(len, vocab))
-    self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
-    self.replace_space = "Ġ"
-    self.replace_newline = "Ċ"
+  def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]):
+    self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat)
+    self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() }
+    self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)")

-  def encode(self, text:str) -> list[int]:
-    s = text.replace(" ", self.replace_space).replace("\n", self.replace_newline)
-    out: list[int] = []
-    i = 0
-    while i < len(s):
-      j = min(i+self.biggest_token, len(s))
-      while i < j and (tid:=self.token_to_id.get(s[i:j])) is None: j -= 1
-      if tid is None: raise RuntimeError(f"token not found in {s}")
-      assert tid is not None, f"token not found in {s}"
-      out.append(tid)
-      i = j
-    return out
+  @staticmethod
+  def from_gguf_kv(kv: dict):
+    # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
+    if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
+    vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
+    normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
+    return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens))

-  def decode(self, ids: list[int]) -> str:
-    return ''.join(self.vocab[tid] for tid in ids).replace(self.replace_space, " ").replace(self.replace_newline, "\n")
+  def encode(self, text: str):
+    tokens: list[int] = []
+    pos = 0
+    for match in self._special_re.finditer(text):
+      tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]])
+      pos = match.end(0)
+    return tokens + self._encode_sentence(text[pos:])

-  def role(self, role:str):
-    return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)] # llama style
+  def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode()
+  def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+
+  def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ]
+  def _encode_word(self, word: bytes):
+    if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
+    parts = [word[i:i+1] for i in range(len(word))]
+    while True:
+      min_tid, min_idx = 2**32, -1
+      for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])):
+        tid = self._normal_tokens.get(p1 + p2, min_tid)
+        if tid < min_tid: min_tid, min_idx = tid, idx
+      if min_idx == -1: break
+      parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:]
+    try: return [ self._normal_tokens[p] for p in parts ]
+    except KeyError: raise RuntimeError("token not found")

 def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000):
   B, H, T, Hd = x.shape
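Two of the new helpers can be sanity-checked on their own. A short sketch (inputs invented; expected outputs worked out from the definitions above):

  import re
  from tinygrad.apps.llm import get_llama_re, gpt2_decode_vocab

  # GPT-2 vocabs store token bytes as printable characters: byte 0x20 (space)
  # becomes "Ġ" (U+0120) and 0x0A (newline) becomes "Ċ"; gpt2_decode_vocab
  # maps the printable form back to raw bytes
  print(gpt2_decode_vocab({ "hello": 0, "Ġworld": 1, "Ċ": 2 }))
  # -> {b'hello': 0, b' world': 1, b'\n': 2}

  # the pretokenizer keeps the leading space on words and caps digit runs at 3
  pat = re.compile(get_llama_re())  # scans every Unicode codepoint, takes a moment
  print(pat.findall("hello world 1234!"))
  # -> ['hello', ' world', ' ', '123', '4', '!']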
@@ -165,7 +189,7 @@ if __name__ == "__main__":
   model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)

   # extract some metadata
-  tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
+  tok = SimpleTokenizer.from_gguf_kv(kv)
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
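Note how encode cuts special tokens out with _special_re before any BPE runs, so specials are matched verbatim and never split by the pretokenizer regex. A toy sketch using the same vocab as the unit tests above:

  from tinygrad.apps.llm import SimpleTokenizer

  tok = SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "<x>": 5 })
  print(tok.encode("ab<x>c"))   # -> [3, 5, 2]: "<x>" is looked up directly
  print(tok.decode([3, 5, 2]))  # -> "ab<x>c"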