qwen work
@@ -10,6 +10,8 @@ class TestLLMServer(unittest.TestCase):
 cls.mock_tok.role = Mock(return_value=[100, 101])
 cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
 cls.mock_tok.decode = Mock(return_value="Hello")
+cls.mock_tok.end_turn = Mock(return_value=[999])
+cls.mock_tok.stop_tokens = Mock(return_value={999})

 cls.mock_model = Mock()
 cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
@@ -4,7 +4,8 @@ from tinygrad import Tensor, nn, UOp, TinyJit, getenv
 from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG

 class SimpleTokenizer:
-def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
+def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
+self._preset = preset
 # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
 bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
 self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
@@ -23,10 +24,11 @@ class SimpleTokenizer:
 @staticmethod
 def from_gguf_kv(kv:dict):
 # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
-if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
+preset = kv["tokenizer.ggml.pre"]
+if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
 vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
 normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
-return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
+return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), preset)

 def _encode_word(self, word:bytes) -> list[int]:
 if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
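As a side note on from_gguf_kv: the vocab split relies on GGUF token-type metadata. A minimal sketch with a hypothetical, trimmed-down kv dict (the token names and type codes below are made up for illustration; in GGUF, type 1 marks a normal token and control tokens use other codes):

```python
from tinygrad.helpers import partition

# hypothetical GGUF metadata, illustration only
kv = {"tokenizer.ggml.tokens": ["a", "b", "<|im_end|>"], "tokenizer.ggml.token_type": [1, 1, 3]}
vocab = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
# entries with token_type 1 (normal) go left, everything else (specials) goes right
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
print(dict(normal_tokens))   # {'a': 0, 'b': 1}
print(dict(special_tokens))  # {'<|im_end|>': 2}
```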
@@ -49,7 +51,16 @@ class SimpleTokenizer:
 return tokens + self._encode_sentence(text[pos:])

 def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
-def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+def role(self, role:str):
+if self._preset == "qwen2": return self.encode(f"<|im_start|>{role}\n")
+return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
+def end_turn(self):
+if self._preset == "qwen2": return self.encode("<|im_end|>\n")
+return self.encode("<|eot_id|>")
+def stop_tokens(self) -> set[int]:
+"""Returns tokens that indicate end of generation (subset of end_turn)"""
+if self._preset == "qwen2": return {self._special_tokens["<|im_end|>"]}
+return {self._special_tokens["<|eot_id|>"]}

 def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
 B, H, T, Hd = x.shape
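For reference, a minimal sketch of how role() and end_turn() frame a single chat turn under each preset (plain strings for illustration; the real methods return token ids via encode()):

```python
# illustration only: the framing each preset produces around one turn
def frame_turn(preset: str, role: str, text: str) -> str:
  if preset == "qwen2":  # ChatML-style markers used by Qwen GGUFs
    return f"<|im_start|>{role}\n{text}<|im_end|>\n"
  return f"<|start_header_id|>{role}<|end_header_id|>\n\n{text}<|eot_id|>"  # llama3-style

print(frame_turn("qwen2", "user", "hi"))
print(frame_turn("llama3", "user", "hi"))
```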
@@ -64,22 +75,28 @@ def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
 x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)

 class TransformerBlock:
-def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
+def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0, head_dim:int|None=None, rope_base:float=10000.0, qk_norm:bool=False):
 self.n_heads = n_heads
 self.n_kv_heads = n_kv_heads
-self.head_dim = dim // n_heads
+self.head_dim = head_dim if head_dim is not None else dim // n_heads
 self.max_context = max_context
+self.rope_base = rope_base

 # --- attention projections (all linear, bias-free) ------------------
-kv_proj_out = self.head_dim * n_kv_heads # Llama-3 uses the same dim for K/V
-self.attn_q = nn.Linear(dim, dim, bias=False)
+q_proj_out = self.head_dim * n_heads
+kv_proj_out = self.head_dim * n_kv_heads
+self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
 self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
 self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
-self.attn_output = nn.Linear(dim, dim, bias=False)
+self.attn_output = nn.Linear(q_proj_out, dim, bias=False)

 # --- RMSNorms --------------------------------------------------------
 self.attn_norm = nn.RMSNorm(dim, norm_eps)
 self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+# QK normalization (Qwen3-style)
+if qk_norm:
+self.attn_q_norm = nn.RMSNorm(self.head_dim, norm_eps)
+self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps)

 # --- feed-forward ----------------------------------------------------
 self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
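Why attn_q and attn_output switch to q_proj_out: once head_dim can be set explicitly, n_heads * head_dim no longer has to equal dim. A quick sketch with assumed Qwen3-0.6B-like numbers (dim=1024, n_heads=16, n_kv_heads=8, head_dim=128; treat these as illustrative, not authoritative):

```python
# assumed Qwen3-0.6B-like config, for illustrating the projection widths
dim, n_heads, n_kv_heads, head_dim = 1024, 16, 8, 128
q_proj_out  = head_dim * n_heads     # 2048: no longer equal to dim
kv_proj_out = head_dim * n_kv_heads  # 1024
# attn_q:      dim -> q_proj_out   (1024 -> 2048)
# attn_k/v:    dim -> kv_proj_out  (1024 -> 1024)
# attn_output: q_proj_out -> dim   (2048 -> 1024), hence nn.Linear(q_proj_out, dim) instead of nn.Linear(dim, dim)
assert q_proj_out != dim
```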
@@ -95,8 +112,12 @@ class TransformerBlock:
 k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
 v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)

-q = apply_rope(q, start_pos)
-k = apply_rope(k, start_pos)
+# QK normalization (applied before RoPE)
+if hasattr(self, "attn_q_norm"): q = self.attn_q_norm(q)
+if hasattr(self, "attn_k_norm"): k = self.attn_k_norm(k)
+
+q = apply_rope(q, start_pos, self.rope_base)
+k = apply_rope(k, start_pos, self.rope_base)

 # TODO: remove these kv cache realizes
 if not hasattr(self, "cache_kv"):
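A minimal shape-level sketch of the new ordering, assuming tinygrad's nn.RMSNorm normalizes over the trailing axis, so applying it to a (B, H, T, Hd) tensor normalizes each head independently before RoPE:

```python
from tinygrad import Tensor, nn

B, H, T, Hd = 1, 2, 4, 8
q = Tensor.randn(B, H, T, Hd)
attn_q_norm = nn.RMSNorm(Hd, 1e-5)  # per-head RMSNorm over head_dim, as in the qk_norm branch above
q = attn_q_norm(q)                  # normalize first...
# q = apply_rope(q, start_pos, rope_base)  # ...then rotate (apply_rope is defined in this file)
print(q.shape)                      # (1, 2, 4, 8)
```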
@@ -121,8 +142,8 @@ class TransformerBlock:
 return self._feed_forward(self._attention(x, start_pos)).contiguous()

 class Transformer:
-def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
-self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
+def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context, head_dim=None, rope_base=10000.0, qk_norm=False):
+self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, head_dim, rope_base, qk_norm) for _ in range(num_blocks)]
 self.token_embd = nn.Embedding(vocab_size, dim)
 self.output_norm = nn.RMSNorm(dim, norm_eps)
 self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -152,9 +173,13 @@ class Transformer:

 arch = kv['general.architecture']
 max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
+head_dim = kv.get(f'{arch}.attention.key_length') # None for models that use dim // n_heads
+rope_base = kv.get(f'{arch}.rope.freq_base', 10000.0)
+qk_norm = 'blk.0.attn_q_norm.weight' in state_dict # detect QK normalization (Qwen3-style)
 model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
 n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
-norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
+norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context,
+head_dim=head_dim, rope_base=rope_base, qk_norm=qk_norm)
 nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
 # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
 for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
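To illustrate the new metadata fallbacks, a sketch with two hypothetical kv dicts (the qwen3 values below are assumptions for illustration, not read from a real file):

```python
# hypothetical GGUF metadata, trimmed to the keys the loader reads above
llama_kv = {"general.architecture": "llama"}  # no key_length / freq_base overrides
qwen3_kv = {"general.architecture": "qwen3", "qwen3.attention.key_length": 128, "qwen3.rope.freq_base": 1000000.0}

for kv in (llama_kv, qwen3_kv):
  arch = kv["general.architecture"]
  head_dim = kv.get(f"{arch}.attention.key_length")      # None -> TransformerBlock falls back to dim // n_heads
  rope_base = kv.get(f"{arch}.rope.freq_base", 10000.0)  # default matches the old hard-coded RoPE base
  print(arch, head_dim, rope_base)
# llama None 10000.0
# qwen3 128 1000000.0
```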
@@ -178,6 +203,7 @@ models = {
 "llama3.2:3b": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
 "llama3.2:3b-f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
 "llama3.1:8b": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
+"qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
 }

 # *** simple OpenAI compatible server on 11434 to match ollama ***
@@ -187,9 +213,9 @@ class Handler(HTTPRequestHandler):
 def run_model(self, ids:list[int], model_name:str, include_usage=False):
 tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
 yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
-out = []
+out, stop_ids = [], tok.stop_tokens()
 for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
-if next_id == eos_id: break
+if next_id in stop_ids: break
 out.append(next_id)
 yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
 yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
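The run_model change generalizes the single eos_id check to a set of stop tokens. A self-contained sketch of that loop, reusing the mocked values from the test hunk at the top (generate yields 300, 301, 999 and 999 is the only stop token):

```python
def stream(token_iter, stop_ids: set[int], decode):
  out = []
  for next_id in token_iter:
    if next_id in stop_ids: break  # any end-of-turn token stops generation (<|eot_id|> or <|im_end|>)
    out.append(next_id)
    yield decode([next_id])

chunks = list(stream(iter([300, 301, 999]), {999}, lambda ids: f"<tok{ids[0]}>"))
print(chunks)  # ['<tok300>', '<tok301>']
```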
@@ -215,7 +241,8 @@ class Handler(HTTPRequestHandler):
 if c["type"] == "text": ids += tok.encode(c["text"])
 else: raise RuntimeError(f"unhandled type: {c['type']}")
 else: raise RuntimeError(f"unknown content type: {type(content)}")
-ids += tok.role("assistant")
+ids += tok.end_turn()
+ids += tok.role("assistant")

 # reply
 chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
@@ -247,14 +274,15 @@ if __name__ == "__main__":
 # start server
 if args.serve: TCPServerWithReuse(('', 11434), Handler).serve_forever()

+stop_ids = tok.stop_tokens()
 ids: list[int] = [bos_id]
 while 1:
 start_pos = len(ids) - 1
 try:
-ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
+ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn() + tok.role("assistant")
 except EOFError:
 break
 for next_id in model.generate(ids, start_pos):
-sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
+sys.stdout.write(tok.decode([next_id]) if next_id not in stop_ids else "\n\n")
 sys.stdout.flush()
-if next_id == eos_id: break
+if next_id in stop_ids: break