Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
qwen model is working (#13690)
* qwen model is mostly working

* add Q4_K quantization support to GGUF parser, add qwen3:1.7b model

  - Add Q4_K (type 12) dequantization in nn/state.py
  - Add qwen3:1.7b model using Q4_K_M quantization (smaller than Q8_0)
  - Make bos_token_id optional for models like Qwen3 that don't have it
  - Fix line length issues and add preset parameter to SimpleTokenizer

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* smaller diff
* test dequant
* half split
* better
* simple tok
* mock token
* polish
* better
* fix
* replace

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
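The Q4_K (type 12) support mentioned above lives in nn/state.py, whose diff is not shown on this page. As a rough reference for what that dequantization involves, here is a minimal NumPy sketch following the llama.cpp Q4_K block layout (144-byte super-blocks of 256 weights: fp16 d and dmin, 12 bytes of packed 6-bit scales/mins, 128 bytes of 4-bit quants); the actual tinygrad implementation is tensor-based and will differ in detail:

```python
import numpy as np

def dequant_q4_k(raw: bytes, n: int) -> np.ndarray:
  """Dequantize GGML Q4_K data: n values, n % 256 == 0, one 144-byte super-block per 256 weights."""
  out = np.empty(n, dtype=np.float32)
  blocks = np.frombuffer(raw, dtype=np.uint8).reshape(n // 256, 144)
  for i, blk in enumerate(blocks):
    d    = np.frombuffer(blk[0:2].tobytes(), dtype=np.float16)[0].astype(np.float32)  # super-block scale
    dmin = np.frombuffer(blk[2:4].tobytes(), dtype=np.float16)[0].astype(np.float32)  # super-block min scale
    scales, qs, y = blk[4:16], blk[16:144], out[i*256:(i+1)*256]
    for j in range(8):                                   # 8 sub-blocks of 32 values, 6-bit scale/min each
      if j < 4: sc, mn = scales[j] & 63, scales[j+4] & 63
      else:
        sc = (scales[j+4] & 0xF) | ((scales[j-4] >> 6) << 4)
        mn = (scales[j+4] >> 4)  | ((scales[j]   >> 6) << 4)
      q = qs[(j//2)*32:(j//2)*32+32]
      nib = (q & 0xF) if j % 2 == 0 else (q >> 4)        # even sub-blocks take low nibbles, odd take high
      y[j*32:(j+1)*32] = d * sc * nib.astype(np.float32) - dmin * mn
  return out
```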
@@ -10,6 +10,7 @@ class TestLLMServer(unittest.TestCase):
     cls.mock_tok.role = Mock(return_value=[100, 101])
     cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
     cls.mock_tok.decode = Mock(return_value="Hello")
+    cls.mock_tok.end_turn = Mock(return_value=[998])
 
     cls.mock_model = Mock()
     cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
@@ -4,7 +4,8 @@ from tinygrad import Tensor, nn, UOp, TinyJit, getenv
 from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
 
 class SimpleTokenizer:
-  def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
+  def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
+    if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
     # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
     bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
     self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
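Aside: the `bs` / `_byte_decoder` tables above are the standard GPT-2 byte-to-unicode trick that GGUF BPE vocabularies rely on. A minimal standalone check of what the decoder does:

```python
# GGUF stores BPE tokens as printable unicode strings; this maps each character back to its raw byte
bs = [*range(33, 127), *range(161, 173), *range(174, 256)]                 # bytes that map to themselves
byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}

tok = "Ġhello"                                            # how a BPE vocab spells " hello"
assert bytes(byte_decoder[c] for c in tok) == b" hello"   # 'Ġ' (U+0120) is the stand-in for the space byte 0x20
assert sorted(byte_decoder.values()) == list(range(256))  # every byte has exactly one printable stand-in
```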
@@ -20,14 +21,14 @@ class SimpleTokenizer:
     self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
     self._special_tokens = special_tokens
     self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
+    self.preset = preset
 
   @staticmethod
   def from_gguf_kv(kv:dict):
     # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
-    if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
     vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
     normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
-    return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
+    return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])
 
   def _encode_word(self, word:bytes) -> list[int]:
     if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
@@ -49,8 +50,11 @@ class SimpleTokenizer:
       pos = match.end(0)
     return tokens + self._encode_sentence(text[pos:])
 
-  def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
-  def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+  def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
+  def role(self, role:str):
+    if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
+    return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+  def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
 
 @functools.cache
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
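For reference, the two presets wrap a chat turn in different templates: role() emits the header and end_turn() closes it. A string-level sketch of the concatenation (the eos strings shown, <|im_end|> for qwen2 and <|eot_id|> for llama3, are the usual ones for those families and are assumptions here; the server actually appends the eos_id from GGUF metadata):

```python
# string-level sketch of what role() + content + end_turn() concatenate to, per preset
def render_turn(preset: str, role: str, content: str, eos: str) -> str:
  if preset == "qwen2":  # Qwen3 GGUFs report tokenizer.ggml.pre == "qwen2"
    return f"<|im_start|>{role}\n{content}{eos}\n"                          # role() .. end_turn(): eos then newline
  return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}{eos}"    # llama3 family: just eos

print(render_turn("qwen2", "user", "hi", "<|im_end|>"))
print(render_turn("llama3", "user", "hi", "<|eot_id|>"))
```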
@@ -65,22 +69,26 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
   return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
 
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
+               max_context:int=0, qk_norm:bool=False):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads
-    self.head_dim = dim // n_heads
+    self.head_dim = head_dim
     self.max_context = max_context
+    self.rope_theta = rope_theta
 
     # --- attention projections (all linear, bias-free) ------------------
-    kv_proj_out = self.head_dim * n_kv_heads # Llama-3 uses the same dim for K/V
-    self.attn_q = nn.Linear(dim, dim, bias=False)
+    q_proj_out = self.head_dim * n_heads
+    kv_proj_out = self.head_dim * n_kv_heads
+    self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
     self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
     self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
-    self.attn_output = nn.Linear(dim, dim, bias=False)
+    self.attn_output = nn.Linear(q_proj_out, dim, bias=False)
 
     # --- RMSNorms --------------------------------------------------------
     self.attn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+    if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
 
     # --- feed-forward ----------------------------------------------------
     self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
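Why head_dim is now an explicit argument instead of dim // n_heads: Qwen3 fixes the per-head width at 128 regardless of model width, so the Q/K/V projection sizes no longer coincide with dim. The numbers below use the published Qwen3-0.6B shape (hidden 1024, 16 query heads, 8 KV heads, head_dim 128) as an illustration; treat them as assumptions rather than values read from this diff:

```python
# illustrative shapes for a qwen3-style config where head_dim != dim // n_heads
dim, n_heads, n_kv_heads, head_dim = 1024, 16, 8, 128

derived_head_dim = dim // n_heads     # 64   <- what the old llama-only code would have used
q_proj_out  = head_dim * n_heads      # 2048 -> attn_q is nn.Linear(1024, 2048), not (dim, dim)
kv_proj_out = head_dim * n_kv_heads   # 1024 -> attn_k / attn_v output width
print(derived_head_dim, q_proj_out, kv_proj_out)  # 64 2048 1024
# and attn_output maps the concatenated heads back down: nn.Linear(q_proj_out, dim) == (2048 -> 1024)
```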
@@ -96,8 +104,10 @@ class TransformerBlock:
     k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
 
+    if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
+
     # TODO: make UOp have SupportsIndex
-    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
+    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
     q = apply_rope(q, freqs_cis)
     k = apply_rope(k, freqs_cis)
 
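rope_theta is threaded through here because the frequency base differs per model family, while the old code silently used precompute_freqs_cis's default of 10000.0. A small sketch of the frequencies that value controls; the 1e6 figure is only an example of the kind of value a qwen3 GGUF might report in <arch>.rope.freq_base, not a number taken from this diff:

```python
# the usual rotary-embedding frequencies: freqs[i] = 1 / theta**(2i/head_dim);
# apply_rope above rotates the (x1, x2) halves of each head by these angles scaled by position
import numpy as np

def rope_freqs(head_dim: int, theta: float) -> np.ndarray:
  return 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)

print(rope_freqs(8, 10000.0))    # the default theta the old code always used
print(rope_freqs(8, 1000000.0))  # a larger freq_base (illustrative) slows the rotation for long contexts
```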
@@ -125,8 +135,10 @@ class TransformerBlock:
     return self._feed_forward(self._attention(x, start_pos)).contiguous()
 
 class Transformer:
-  def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
-    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
+  def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
+               max_context:int=0, qk_norm:bool=False):
+    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm)
+                for _ in range(num_blocks)]
     self.token_embd = nn.Embedding(vocab_size, dim)
     self.output_norm = nn.RMSNorm(dim, norm_eps)
     self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -159,13 +171,15 @@ class Transformer:
   n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
 
   # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
-  for name in state_dict:
-    if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
-    if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
+  if arch != 'qwen3':
+    for name in state_dict:
+      if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
+      if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
 
-  model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
-                      n_heads=n_heads, n_kv_heads=n_kv_heads,
-                      norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
+  model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
+                      n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
+                      vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
+                      rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict)
   nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
   # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
   for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
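The rearrange above only reorders rows of the Q/K weight matrices: llama-family GGUFs store each head's rotary pairs interleaved, apply_rope expects the half-split order, and qwen3 checkpoints are already half-split, hence the new guard. A tiny NumPy stand-in (toy sizes, made up here) makes the permutation visible:

```python
# what "(n h two) d -> (n two h) d" does to row order within each head (toy sizes)
import numpy as np
n_heads, head_dim = 2, 6
rows = np.arange(n_heads * head_dim)                    # stand-in for the first axis of attn_q.weight
half_split = rows.reshape(n_heads, head_dim // 2, 2).transpose(0, 2, 1).reshape(-1)  # (n h two) -> (n two h)
print(rows.reshape(n_heads, -1))        # [[ 0  1  2  3  4  5] [ 6  7  8  9 10 11]]  interleaved pairs
print(half_split.reshape(n_heads, -1))  # [[ 0  2  4  1  3  5] [ 6  8 10  7  9 11]]  evens then odds per head
```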
@@ -190,6 +204,9 @@ models = {
   "llama3.2:3b": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
   "llama3.2:3b-f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
   "llama3.1:8b": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
+  "qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
+  "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
+  "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
 }
 
 # *** simple OpenAI compatible server on 11434 to match ollama ***
@@ -255,7 +272,7 @@ class Handler(HTTPRequestHandler):
     if DEBUG >= 1: print(json.dumps(body, indent=2))
     if self.path == "/v1/chat/completions":
       # extract tokens
-      ids = [bos_id]
+      ids: list[int] = [bos_id] if bos_id is not None else []
       for msg in body["messages"]:
         ids += tok.role(msg["role"])
         # content can be a str or a list
@@ -266,7 +283,8 @@ class Handler(HTTPRequestHandler):
             if c["type"] == "text": ids += tok.encode(c["text"])
             else: raise RuntimeError(f"unhandled type: {c['type']}")
         else: raise RuntimeError(f"unknown content type: {type(content)}")
-      ids += tok.role("assistant")
+        ids += tok.end_turn(eos_id)
+      ids += tok.role("assistant")
 
       # reply
       chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
@@ -302,17 +320,17 @@ if __name__ == "__main__":
 
   # extract some metadata
   tok = SimpleTokenizer.from_gguf_kv(kv)
-  bos_id: int = kv['tokenizer.ggml.bos_token_id']
+  bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
 
   # start server
   if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever()
 
-  ids: list[int] = [bos_id]
+  ids: list[int] = [bos_id] if bos_id is not None else []
   while 1:
-    start_pos = len(ids) - 1
+    start_pos = max(len(ids) - 1, 0)
     try:
-      ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
+      ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
     except EOFError:
       break
    for next_id in model.generate(ids, start_pos):
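The bos handling above is what lets a Qwen3 prompt start with the first role header instead of a bos token. A small standalone sketch of the same logic; the kv dicts here are toy inputs, not values read from a real GGUF:

```python
# resolve the optional bos id from GGUF metadata the same way the script above does
def resolve_bos(kv: dict) -> int | None:
  return kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None

assert resolve_bos({'tokenizer.ggml.bos_token_id': 128000}) == 128000   # llama3-style: bos present, added by default
assert resolve_bos({'tokenizer.ggml.add_bos_token': False}) is None     # qwen3-style: no leading bos
bos_id = resolve_bos({'tokenizer.ggml.add_bos_token': False})
ids: list[int] = [bos_id] if bos_id is not None else []                 # prompt starts empty, role() comes first
```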