qwen model is working (#13690)

* qwen model is mostly working

* add Q4_K quantization support to GGUF parser, add qwen3:1.7b model

- Add Q4_K (type 12) dequantization in nn/state.py (see the sketch below)
- Add qwen3:1.7b model using Q4_K_M quantization (smaller than Q8_0)
- Make bos_token_id optional for models like Qwen3 that don't have it
- Fix line length issues and add preset parameter to SimpleTokenizer
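
For reference, a minimal NumPy sketch of Q4_K (GGUF tensor type 12) dequantization, following the ggml/llama.cpp block layout: 144-byte super-blocks holding 256 weights, two fp16 super-scales (d, dmin), 12 bytes of packed 6-bit sub-block scales/mins, and 128 bytes of 4-bit quants. This is only an illustration of the format; the actual tinygrad code in nn/state.py expresses the same unpacking with Tensor ops and may differ in structure.

import numpy as np

def _scale_min(j: int, s) -> tuple[int, int]:
  # unpack the 6-bit scale and 6-bit min for sub-block j from the 12 packed scale bytes
  if j < 4: return s[j] & 63, s[j + 4] & 63
  return (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4), (s[j + 4] >> 4) | ((s[j] >> 6) << 4)

def dequantize_q4_k_block(block: bytes):
  # one super-block: d, dmin (fp16) + 12 scale bytes + 128 nibble bytes = 144 bytes -> 256 float weights
  assert len(block) == 144
  d, dmin = np.frombuffer(block[:4], dtype=np.float16).astype(np.float32)
  scales, qs = np.frombuffer(block[4:16], dtype=np.uint8), np.frombuffer(block[16:], dtype=np.uint8)
  out, sub = [], 0
  for q in range(0, 128, 32):  # each 32-byte chunk of nibbles covers two sub-blocks of 32 weights
    sc1, m1 = _scale_min(sub, scales)
    sc2, m2 = _scale_min(sub + 1, scales)
    out.append(d * sc1 * (qs[q:q+32] & 0xF) - dmin * m1)  # low nibbles
    out.append(d * sc2 * (qs[q:q+32] >> 4) - dmin * m2)   # high nibbles
    sub += 2
  return np.concatenate(out)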

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* smaller diff

* test dequant

* half split

* better

* simple tok

* mock token

* polish

* better

* fix

* replace

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Author: George Hotz (committed by GitHub)
Date: 2025-12-15 18:00:34 -04:00
Commit: 321ab943b2 (parent: d43e4c7553)
2 changed files with 43 additions and 24 deletions

Changed file 1 of 2 (LLM server tests):

@@ -10,6 +10,7 @@ class TestLLMServer(unittest.TestCase):
cls.mock_tok.role = Mock(return_value=[100, 101])
cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
cls.mock_tok.decode = Mock(return_value="Hello")
cls.mock_tok.end_turn = Mock(return_value=[998])
cls.mock_model = Mock()
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
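
With these mocks the chat-completions path exercised by the tests becomes fully deterministic. Ignoring whatever BOS id the mocked metadata may or may not prepend, a single user message is assembled (per the handler shown further down) as role("user") + encode(content) + end_turn(eos) + role("assistant") = [100, 101, 200, 201, 202, 998, 100, 101]; the mocked model then streams 300, 301, 999, and every decode() call simply returns "Hello".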

Changed file 2 of 2 (the LLM model/tokenizer/server module):

@@ -4,7 +4,8 @@ from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
@@ -20,14 +21,14 @@ class SimpleTokenizer:
self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
self._special_tokens = special_tokens
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
self.preset = preset
@staticmethod
def from_gguf_kv(kv:dict):
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])
def _encode_word(self, word:bytes) -> list[int]:
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
@@ -49,8 +50,11 @@ class SimpleTokenizer:
pos = match.end(0)
return tokens + self._encode_sentence(text[pos:])
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
def role(self, role:str):
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
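
Concretely, the presets differ only in the chat markup they emit: llama3 wraps a role as "<|start_header_id|>role<|end_header_id|>\n\n" and closes a turn with the bare EOS id, while qwen2 uses ChatML. Assuming the GGUF's EOS token is <|im_end|> (as in Qwen's chat models), a single user turn followed by the assistant header decodes to roughly:

<|im_start|>user
Why is the sky blue?<|im_end|>
<|im_start|>assistant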
@@ -65,22 +69,26 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.head_dim = head_dim
self.max_context = max_context
self.rope_theta = rope_theta
# --- attention projections (all linear, bias-free) ------------------
kv_proj_out = self.head_dim * n_kv_heads # Llama-3 uses the same dim for K/V
self.attn_q = nn.Linear(dim, dim, bias=False)
q_proj_out = self.head_dim * n_heads
kv_proj_out = self.head_dim * n_kv_heads
self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(dim, dim, bias=False)
self.attn_output = nn.Linear(q_proj_out, dim, bias=False)
# --- RMSNorms --------------------------------------------------------
self.attn_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
# --- feed-forward ----------------------------------------------------
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
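
The head_dim parameter matters because Qwen3 decouples the per-head width from the model width, so dim // n_heads is no longer valid. A small sanity check using numbers in the ballpark of Qwen3-0.6B's published config (quoted from memory, for illustration only):

# illustrative Qwen3-0.6B-like shapes; the exact figures are an assumption, not read from a GGUF here
dim, n_heads, n_kv_heads, head_dim = 1024, 16, 8, 128
q_proj_out, kv_proj_out = head_dim * n_heads, head_dim * n_kv_heads  # 2048 and 1024, neither equal to dim
assert dim // n_heads != head_dim  # the old formula would have given 64

Llama-3 checkpoints keep head_dim == dim // n_heads, so their projection shapes are unchanged; the qk_norm flag adds the per-head RMSNorm that Qwen3 applies to Q and K before RoPE.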
@@ -96,8 +104,10 @@ class TransformerBlock:
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
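
Passing rope_theta through is what lets Qwen3 use its much larger rope.freq_base (on the order of 1e6) instead of the 10000.0 default in precompute_freqs_cis. As a sketch of the math the call above relies on (the real function builds the table with Tensor ops, and its exact output layout for apply_rope may differ):

import numpy as np

def rope_freqs_sketch(head_dim: int, max_pos: int, theta: float = 10000.0):
  # one frequency per channel pair; the rotation angle grows linearly with position
  freqs = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)  # (head_dim//2,)
  angles = np.outer(np.arange(max_pos), freqs)                   # (max_pos, head_dim//2)
  return np.cos(angles), np.sin(angles)                          # the cos/sin pairs apply_rope consumes

apply_rope above then rotates in half-split form: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin), where x1 and x2 are the first and second halves of each head.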
@@ -125,8 +135,10 @@ class TransformerBlock:
return self._feed_forward(self._attention(x, start_pos)).contiguous()
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm)
for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -159,13 +171,15 @@ class Transformer:
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
# permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
for name in state_dict:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
if arch != 'qwen3':
for name in state_dict:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
n_heads=n_heads, n_kv_heads=n_kv_heads,
norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
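
The permute is skipped for qwen3 because its GGUF Q/K weights appear to already be stored in the half-split (NeoX-style) layout that this file's apply_rope expects, whereas Llama GGUFs store them interleaved. A tiny worked example of the per-head rearrange, assuming a toy head_dim of 6:

import numpy as np

head_dim = 6
rows = np.arange(head_dim)                         # stand-in for one head's rows of attn_q.weight
# "(h two) d -> (two h) d" with two=2: row h*2+t moves to t*(head_dim//2)+h
half_split = rows.reshape(head_dim // 2, 2).T.reshape(head_dim)
print(half_split.tolist())                         # [0, 2, 4, 1, 3, 5]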
@@ -190,6 +204,9 @@ models = {
"llama3.2:3b": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
"llama3.2:3b-f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
"llama3.1:8b": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
"qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
"qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
"qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
}
# *** simple OpenAI compatible server on 11434 to match ollama ***
@@ -255,7 +272,7 @@ class Handler(HTTPRequestHandler):
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens
ids = [bos_id]
ids: list[int] = [bos_id] if bos_id is not None else []
for msg in body["messages"]:
ids += tok.role(msg["role"])
# content can be a str or a list
@@ -266,7 +283,8 @@ class Handler(HTTPRequestHandler):
if c["type"] == "text": ids += tok.encode(c["text"])
else: raise RuntimeError(f"unhandled type: {c['type']}")
else: raise RuntimeError(f"unknown content type: {type(content)}")
ids += tok.role("assistant")
ids += tok.end_turn(eos_id)
ids += tok.role("assistant")
# reply
chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
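
As a usage sketch against the handler above (assuming the server listens on 11434, the ollama-compatible port named in the comment above; the real port is whatever args.serve was given), a request only needs the fields the handler actually reads:

import json, urllib.request

body = {"model": "qwen3:1.7b", "stream": False,
        "messages": [{"role": "user", "content": "Why is the sky blue?"}]}
req = urllib.request.Request("http://localhost:11434/v1/chat/completions", data=json.dumps(body).encode(),
                             headers={"Content-Type": "application/json"})
print(urllib.request.urlopen(req).read().decode())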
@@ -302,17 +320,17 @@ if __name__ == "__main__":
# extract some metadata
tok = SimpleTokenizer.from_gguf_kv(kv)
bos_id: int = kv['tokenizer.ggml.bos_token_id']
bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
eos_id: int = kv['tokenizer.ggml.eos_token_id']
# start server
if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever()
ids: list[int] = [bos_id]
ids: list[int] = [bos_id] if bos_id is not None else []
while 1:
start_pos = len(ids) - 1
start_pos = max(len(ids) - 1, 0)
try:
ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
except EOFError:
break
for next_id in model.generate(ids, start_pos):