diff --git a/test/null/test_llm_server.py b/test/null/test_llm_server.py
index 1b91e053ad..5bc00eb84f 100644
--- a/test/null/test_llm_server.py
+++ b/test/null/test_llm_server.py
@@ -12,6 +12,7 @@ class TestLLMServer(unittest.TestCase):
     cls.mock_tok.decode = Mock(return_value="Hello")
     cls.mock_tok.stream_decoder = Mock(return_value=lambda tid=None: "Hello" if tid is not None else "")
     cls.mock_tok.end_turn = Mock(return_value=[998])
+    cls.mock_tok.prefix = Mock(return_value=[1])
     cls.mock_tok.preset = "llama3"
 
     cls.mock_model = Mock()
@@ -27,6 +28,7 @@ class TestLLMServer(unittest.TestCase):
     llm_module.tok = cls.mock_tok
     llm_module.bos_id = cls.bos_id
     llm_module.eos_id = cls.eos_id
+    llm_module.eot_id = None
 
     from tinygrad.apps.llm import Handler
     from tinygrad.viz.serve import TCPServerWithReuse
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 4f5ae6b570..8eabc00d36 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -9,7 +9,8 @@ from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
 class SimpleTokenizer:
   def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
     preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
-    if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
+    if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken","glm4"):
+      raise ValueError(f"Invalid tokenizer preset '{preset}'")
     # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
     bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
     self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
@@ -63,6 +64,7 @@ class SimpleTokenizer:
     if self.preset == 'olmo': return self.encode("<|" + role + "|>\n")  # OLMoE Instruct format
     if self.preset == 'kimi-k2': return self.encode("<|im_" + role + "|>" + role + "<|im_middle|>")
     if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
+    if self.preset == 'glm4': return self.encode("<|" + role + "|>")
     if self.preset == 'tekken':
       if role == 'user': return self.encode("[INST]")
       if role == 'assistant': return []
@@ -72,8 +74,11 @@ class SimpleTokenizer:
     if self.preset == 'olmo': return self.encode("\n")
     if self.preset == 'kimi-k2': return [eos_id]
     if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
+    if self.preset == 'glm4': return []
     if self.preset == 'tekken': return self.encode("[/INST]")
     return [eos_id]
+  def prefix(self, bos_id:int|None) -> list[int]:
+    return ([] if bos_id is None else [bos_id]) + (self.encode("[gMASK]<sop>") if self.preset == 'glm4' else [])
 
 @functools.cache
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -129,6 +134,7 @@ class TransformerConfig:
   num_experts: int = 0
   num_experts_per_tok: int = 0
   norm_topk_prob: bool = False
+  q_lora_rank: int = 0
   kv_lora_rank: int = 0
   shared_expert_dim: int = 0
   full_attention_interval: int = 0
@@ -259,7 +265,12 @@ class MLATransformerBlock(FFNBlock):
   def __init__(self, config:TransformerConfig):
     super().__init__(config)
     qk_nope_head_dim = config.head_dim - config.rope_dim
-    self.attn_q = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+    if config.q_lora_rank > 0:
+      self.attn_q_a = nn.Linear(config.dim, config.q_lora_rank, bias=False)
+      self.attn_q_a_norm = nn.RMSNorm(config.q_lora_rank, config.norm_eps)
+      self.attn_q_b = nn.Linear(config.q_lora_rank, config.n_heads * config.head_dim, bias=False)
+    else:
+      self.attn_q = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
     self.attn_kv_a_mqa = nn.Linear(config.dim, config.kv_lora_rank + config.rope_dim, bias=False)
     self.attn_kv_a_norm = nn.RMSNorm(config.kv_lora_rank, config.norm_eps)
     self.attn_k_b = {"weight": Tensor.zeros(config.n_heads, config.kv_lora_rank, qk_nope_head_dim)}
@@ -269,7 +280,8 @@ class MLATransformerBlock(FFNBlock):
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     B, T, _ = x.shape
     q_nope_head_dim = self.config.head_dim - self.config.rope_dim
-    q = self.attn_q(x).reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2)
+    q_proj = self.attn_q_b(self.attn_q_a_norm(self.attn_q_a(x))) if self.config.q_lora_rank > 0 else self.attn_q(x)
+    q = q_proj.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2)
     q_nope, q_rope = q[..., :q_nope_head_dim], q[..., q_nope_head_dim:]
     q = (q_nope @ self.attn_k_b["weight"].transpose(-1, -2)).cat(apply_rope(q_rope, self.freqs_cis[start_pos:start_pos+T]), dim=-1)
@@ -407,7 +419,7 @@ class Transformer:
     # Permute RoPE weights from interleaved to half-split layout.
     for name in state_dict:
-      if 'attn_q.weight' in name and (arch == 'llama' or kv_lora_rank):
+      if ('attn_q.weight' in name or 'attn_q_b.weight' in name) and (arch == 'llama' or kv_lora_rank):
         w = state_dict[name].reshape(n_heads, state_dict[name].shape[0]//n_heads, -1)
         prefix = head_dim-rope_dim
         state_dict[name] = w[:, :prefix].cat(w[:, prefix:].rearrange("n (h two) d -> n (two h) d", two=2), dim=1).reshape(-1, w.shape[-1])
@@ -429,7 +441,7 @@ class Transformer:
       qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
       num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
       norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
-      kv_lora_rank=kv_lora_rank,
+      kv_lora_rank=kv_lora_rank, q_lora_rank=kv.get(f'{arch}.attention.q_lora_rank', 0),
      leading_dense_blocks=kv.get(f'{arch}.leading_dense_block_count', 0),
      shared_expert_dim=kv.get(f'{arch}.expert_shared_feed_forward_length',
@@ -489,6 +501,7 @@ models = {
   "qwen3.5:35b-a3b": "https://huggingface.co/unsloth/Qwen3.5-35B-A3B-GGUF/resolve/main/Qwen3.5-35B-A3B-Q4_K_M.gguf",
   "olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
   "moonlight": "https://huggingface.co/gabriellarson/Moonlight-16B-A3B-Instruct-GGUF/resolve/main/Moonlight-16B-A3B-Instruct-Q4_K_M.gguf",
+  "glm-4.7-flash": "https://huggingface.co/unsloth/GLM-4.7-Flash-GGUF/resolve/main/GLM-4.7-Flash-Q4_K_M.gguf",
 }
 
 # *** simple OpenAI API compatible server with web interface on http://localhost:8000/ ***
@@ -549,7 +562,7 @@ class Handler(HTTPRequestHandler):
       dec = tok.stream_decoder()
       for next_id in model.generate(ids, temperature=temperature):
         if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
-        if next_id == eos_id: break
+        if next_id in (eos_id, eot_id): break
         out.append(next_id)
         yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
         if max_tokens is not None and len(out) >= max_tokens:
@@ -569,7 +582,7 @@ class Handler(HTTPRequestHandler):
     if DEBUG >= 1: print(json.dumps(body, indent=2))
     if self.path == "/v1/chat/completions":  # extract tokens, last assistant message is treated as prefill
-      ids: list[int] = [bos_id] if bos_id is not None else []
+      ids: list[int] = tok.prefix(bos_id)
       for i, msg in enumerate(body["messages"]):
         ids += tok.role(msg["role"])
         content = msg["content"]
@@ -621,6 +634,7 @@ if __name__ == "__main__":
   tok = SimpleTokenizer.from_gguf_kv(kv)
   bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
+  eot_id: int|None = kv.get('tokenizer.ggml.eot_token_id')
 
   # warmup the JIT
   if args.warmup or args.serve:
@@ -642,7 +656,7 @@ if __name__ == "__main__":
     exit(0)
 
   # interactive chat
-  ids: list[int] = [bos_id] if bos_id is not None else []
+  ids: list[int] = tok.prefix(bos_id)
   while 1:
     try: ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
     except EOFError: break
@@ -650,6 +664,6 @@ if __name__ == "__main__":
     dec = tok.stream_decoder()
     for next_id in model.generate(ids):
-      sys.stdout.write(dec(next_id) if next_id != eos_id else dec() + "\n\n")
+      sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
       sys.stdout.flush()
-      if next_id == eos_id: break
+      if next_id in (eos_id, eot_id): break
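
Reviewer note: a minimal standalone sketch (not code from this PR) of what the new low-rank query path in MLATransformerBlock computes. The dimensions below are illustrative placeholders, not GLM-4.7-Flash's actual config; only the factorization pattern from the diff (down-project, RMSNorm on the bottleneck, up-project) is mirrored.

    # sketch only: attn_q_a / attn_q_a_norm / attn_q_b factor the dense q projection
    # (dim -> n_heads*head_dim) through a q_lora_rank-wide bottleneck
    from tinygrad import Tensor, nn

    dim, n_heads, head_dim, q_lora_rank = 2048, 16, 128, 512  # illustrative values
    x = Tensor.randn(1, 8, dim)  # (batch, seq, dim)

    attn_q_a = nn.Linear(dim, q_lora_rank, bias=False)                 # down-projection
    attn_q_a_norm = nn.RMSNorm(q_lora_rank, 1e-6)                      # norm on the bottleneck
    attn_q_b = nn.Linear(q_lora_rank, n_heads * head_dim, bias=False)  # up-projection

    q = attn_q_b(attn_q_a_norm(attn_q_a(x)))
    assert q.shape == (1, 8, n_heads * head_dim)  # same output shape as the dense attn_q

    # why the factorization: fewer weights than the dense projection it replaces
    dense = dim * n_heads * head_dim                                # 4,194,304 params
    lowrank = dim * q_lora_rank + q_lora_rank * n_heads * head_dim  # 2,097,152 params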
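Similarly, a sketch (again not from the PR) of how the tokenizer hooks now assemble a chat prompt end to end. With the "glm4" preset, role() emits the <|user|>/<|assistant|> tags, end_turn() is empty because the next role tag closes the turn, and prefix() replaces the bare `[bos_id]` initialization used before:

    # hypothetical helper for illustration; it mirrors the Handler and
    # interactive-chat flow in the diff above
    def build_chat(tok, bos_id, eos_id, messages):
      ids = tok.prefix(bos_id)        # bos (if any) plus any preset-specific prompt prefix
      for msg in messages:
        ids += tok.role(msg["role"])  # e.g. encode("<|user|>") for glm4
        ids += tok.encode(msg["content"])
        ids += tok.end_turn(eos_id)   # [] for glm4
      return ids + tok.role("assistant")  # cue the model to respond

Generation then stops on either eos_id or the GGUF's eot_token_id, matching the new `next_id in (eos_id, eot_id)` checks.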