diff --git a/test/null/test_llm_tokenizer.py b/test/null/test_llm_tokenizer.py index a71c1471e3..99e5ff5fa5 100644 --- a/test/null/test_llm_tokenizer.py +++ b/test/null/test_llm_tokenizer.py @@ -46,6 +46,18 @@ class TestLLMTokenizer(unittest.TestCase): def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ]) def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ]) + def test_tekken_from_gguf_kv(self): + kv = { + "tokenizer.ggml.tokens": ["", "", "", "[INST]", "[/INST]", "hello"], + "tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1], + "tokenizer.ggml.pre": "tekken", + } + tok = SimpleTokenizer.from_gguf_kv(kv) + self.assertEqual(tok.role("user"), [3]) + self.assertEqual(tok.encode("hello"), [5]) + self.assertEqual(tok.end_turn(2), [4]) + self.assertEqual(tok.role("assistant"), []) + def test_stream_decoder(self): """stream_decoder buffers incomplete UTF-8: token 25677 has 3/4 of emoji, token 138 completes it.""" bs = [*range(33, 127), *range(161, 173), *range(174, 256)] diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index fb8a3252e2..8f7f066f82 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -9,7 +9,7 @@ from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler class SimpleTokenizer: def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"): preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset) - if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2"): raise ValueError(f"Invalid tokenizer preset '{preset}'") + if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken"): raise ValueError(f"Invalid tokenizer preset '{preset}'") # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9 bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)} @@ -63,11 +63,16 @@ class SimpleTokenizer: if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format if self.preset == 'kimi-k2': return self.encode("<|im_" + role + "|>" + role + "<|im_middle|>") if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n") + if self.preset == 'tekken': + if role == 'user': return self.encode("[INST]") + if role == 'assistant': return [] + raise ValueError(f"Unsupported role '{role}' for tokenizer preset '{self.preset}'") return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n") def end_turn(self, eos_id:int): if self.preset == 'olmo': return self.encode("\n") if self.preset == 'kimi-k2': return [eos_id] if self.preset == 'qwen2': return [eos_id] + self.encode("\n") + if self.preset == 'tekken': return self.encode("[/INST]") return [eos_id] @functools.cache