llm: support assistant prefill + refactor to TransformerConfig (#15457)
* llm: support assistant prefill
* refactor to ModelConfig
* TransformerConfig
* more
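As a rough illustration of what the two changes enable (a hedged sketch pieced together from the diff below; the server URL, API key, and model name are placeholders, and the config values are the ones used in the new tests):

```python
from openai import OpenAI
from tinygrad.apps.llm import Transformer, TransformerConfig

# assistant prefill: a trailing assistant message is treated as the start of the reply
# (no end_turn, no extra role("assistant")), so generation continues from it
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="none")  # placeholder address
resp = client.chat.completions.create(model="test", messages=[
  {"role": "user", "content": "What is 2+2?"},
  {"role": "assistant", "content": "<think>\n\n</think>\n\n"},  # prefill an empty think block to skip thinking
])

# TransformerConfig: model hyperparameters now live in one frozen dataclass
config = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                           norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(config)
```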
test/external/external_llm_eval.py (8 lines changed)
@@ -13,7 +13,7 @@ if __name__ == "__main__":
parser.add_argument("--max_tokens", "-T", type=int, default=4096)
parser.add_argument("--offset", "-O", type=int, default=0)
parser.add_argument("--temperature", "-t", type=float, default=0.0)
parser.add_argument("--no_think", action="store_true", help="append /no_think to disable thinking (for Qwen3)")
parser.add_argument("--no_think", action="store_true", help="disable thinking (prefills empty think block via assistant message)")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

@@ -29,8 +29,10 @@ if __name__ == "__main__":
phrasing = "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\n" +\
f"Question: {question}\n" + '\n'.join([f"{l}. {t}" for l, t in zip(LABEL, choices['text'])]) +\
'\nYour response should end with "The best answer is [the_answer_letter]"' +\
" where the [the_answer_letter] is one of A, B, C or D." + (" /no_think" if args.no_think else "")
resp = client.chat.completions.create(model="test", messages=[{"role": "user", "content": phrasing}],
" where the [the_answer_letter] is one of A, B, C or D."
messages = [{"role": "user", "content": phrasing}]
if args.no_think: messages.append({"role": "assistant", "content": "<think>\n\n</think>\n\n"})
resp = client.chat.completions.create(model="test", messages=messages,
max_tokens=args.max_tokens, temperature=args.temperature)
# normalize answer key (some use 1/2/3/4 instead of A/B/C/D)
correct = answer.as_py().strip()

@@ -153,6 +153,49 @@ class TestLLMServer(unittest.TestCase):
self.assertEqual(resp.choices[0].finish_reason, "length")
self.assertEqual(resp.usage.completion_tokens, 2)

def test_assistant_prefill(self):
"""Last assistant message should be treated as prefill (not a completed turn)."""
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
captured_ids = []
orig_generate = self.mock_model.generate.side_effect
def capture_generate(ids, **kwargs):
captured_ids.extend(ids)
return orig_generate(ids, **kwargs)
self.mock_model.generate = Mock(side_effect=capture_generate)

resp = self.client.chat.completions.create(
model="test", messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Sure"}
], stream=False
)
# prefill tokens should be in ids: role("assistant") + encode("Sure") but NO end_turn after it
# and NO extra role("assistant") appended
role_tokens = self.mock_tok.role.call_args_list
# last role() call should be for "assistant" (the prefill message), not an extra one
self.assertEqual(role_tokens[-1], unittest.mock.call("assistant"))
# end_turn should be called once less than role() — the prefill assistant msg doesn't get end_turn
self.assertEqual(self.mock_tok.end_turn.call_count, self.mock_tok.role.call_count - 1)
self.assertIsNotNone(resp.choices[0].message.content)

def test_assistant_prefill_not_last(self):
"""Assistant message that's NOT last should be a normal completed turn."""
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
self.mock_tok.role.reset_mock()
self.mock_tok.end_turn.reset_mock()
self.client.chat.completions.create(
model="test", messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Sure"},
{"role": "user", "content": "Continue"}
], stream=False
)
# all messages get end_turn, plus an extra role("assistant") at the end
# roles: user, assistant, user, assistant(generation prompt) = 4 role calls
# end_turns: user, assistant, user = 3 end_turn calls (one per message)
self.assertEqual(self.mock_tok.end_turn.call_count, 3)
self.assertEqual(self.mock_tok.role.call_count, 4)

def test_models_endpoint(self):
import requests as req
resp = req.get(f"http://127.0.0.1:{self.port}/v1/models")

@@ -1,15 +1,19 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.apps.llm import TransformerBlock, TransformerConfig

def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
return TransformerConfig(num_blocks=1, dim=dim, hidden_dim=hidden, n_heads=n_heads, n_kv_heads=n_heads,
norm_eps=1e-5, vocab_size=100, head_dim=dim//n_heads, rope_theta=10000, max_context=16,
num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)

class TestMoEFeedForward(unittest.TestCase):
def test_moe_feed_forward(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2

block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block = TransformerBlock(_moe_config(dim, hidden, n_heads, num_experts, k))

# set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
@@ -27,12 +31,10 @@ class TestMoEFeedForward(unittest.TestCase):
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)

def test_moe_feed_forward_batched(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2

block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block = TransformerBlock(_moe_config(dim, hidden, n_heads, num_experts, k))

# same setup as BS=1 test
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
@@ -50,13 +52,11 @@ class TestMoEFeedForward(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)

def test_moe_feed_forward_norm_topk_prob(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2

block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block.norm_topk_prob = True
from dataclasses import replace
block = TransformerBlock(replace(_moe_config(dim, hidden, n_heads, num_experts, k), norm_topk_prob=True))

block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])

@@ -2,13 +2,15 @@ import unittest
from unittest.mock import patch
from tinygrad import Tensor, UOp
from tinygrad.engine.schedule import schedule_cache
from tinygrad.apps.llm import Transformer, TransformerConfig

TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)

class TestTransformerGenerate(unittest.TestCase):
def test_kv_cache_reuse(self):
"""Test that generate reuses the KV cache when tokens extend the cached prefix."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)

captured_inputs = []
def mock_call(self, tokens, start_pos, temperature):
@@ -35,9 +37,7 @@ class TestTransformerGenerate(unittest.TestCase):

def test_kv_cache_invalidation(self):
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)

captured_inputs = []
def mock_call(self, tokens, start_pos, temperature):
@@ -61,9 +61,8 @@ class TestTransformerGenerate(unittest.TestCase):

def test_two_prompts_schedule_cache(self):
"""Third prompt should hit the schedule cache, not miss (first two warm up both jits: prefill + decode)."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
from dataclasses import replace
model = Transformer(replace(TEST_CONFIG, max_context=64))

# first two prompts warm up both jits (prefill + decode)
ids = list(range(1, 6))
@@ -85,10 +84,9 @@ class TestTransformerGenerate(unittest.TestCase):

def test_chunked_prefill(self):
"""When prompt > chunk_size, all chunks should be prefill"""
from tinygrad.apps.llm import Transformer
from tinygrad.uop.ops import resolve
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
from dataclasses import replace
model = Transformer(replace(TEST_CONFIG, max_context=64))

def get_prefill_flags(tokens, chunk_size):
is_prefill = []
@@ -110,9 +108,7 @@ class TestTransformerGenerate(unittest.TestCase):

def test_temperature_zero_is_greedy(self):
"""Temperature 0 (or near 0) should produce deterministic output."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
tokens = list(range(1, 6))
results = [list(zip(range(5), model.generate(list(tokens)))) for _ in range(3)]
# all runs should produce the same tokens
@@ -121,9 +117,7 @@ class TestTransformerGenerate(unittest.TestCase):

def test_temperature_high_produces_variety(self):
"""High temperature should produce different outputs across runs."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
tokens = list(range(1, 6))
runs = set()
for _ in range(5):
@@ -135,9 +129,7 @@ class TestTransformerGenerate(unittest.TestCase):

def test_temperature_passed_to_forward(self):
"""Temperature from generate should be passed through to __call__."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
captured_temps = []
def mock_call(self, tokens, start_pos, temperature):
captured_temps.append(float(temperature.item()))

@@ -1,5 +1,6 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
from dataclasses import dataclass
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.uop.ops import resolve
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
@@ -82,53 +83,62 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)

@dataclass(frozen=True)
class TransformerConfig:
num_blocks: int
dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
head_dim: int
rope_theta: float
max_context: int = 0
qk_norm: int = 0
num_experts: int = 0
num_experts_per_tok: int = 0
norm_topk_prob: bool = False

class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.rope_theta = rope_theta
self.max_context = max_context
self.qk_norm = qk_norm
def __init__(self, config:TransformerConfig):
self.config = config

# --- attention projections (all linear, bias-free) ------------------
q_proj_out = self.head_dim * n_heads
kv_proj_out = self.head_dim * n_kv_heads
self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(q_proj_out, dim, bias=False)
q_proj_out = config.head_dim * config.n_heads
kv_proj_out = config.head_dim * config.n_kv_heads
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(q_proj_out, config.dim, bias=False)

# --- RMSNorms --------------------------------------------------------
self.attn_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)

# --- feed-forward (MoE or dense) -------------------------------------
if num_experts > 0:
self.norm_topk_prob = norm_topk_prob
self.num_experts_per_tok = num_experts_per_tok
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
if config.num_experts > 0:
self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
else:
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
self.ffn_gate = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)

@function
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

B, T, _ = x.shape
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.config.qk_norm == self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
@@ -155,8 +165,8 @@ class TransformerBlock:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
if self.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.config.num_experts_per_tok) # (B, T, k) each
if self.config.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
# TODO: remove the need for this contiguous
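For reference, a minimal standalone sketch of the top-k expert routing this block performs; the shapes and the random router logits below are illustrative, not values from the commit:

```python
from tinygrad import Tensor

B, T, E, k = 1, 3, 4, 2                           # batch, tokens, experts, experts per token
router_logits = Tensor.rand(B, T, E)              # stand-in for self.ffn_gate_inp(h_norm)
probs, sel = router_logits.softmax(-1).topk(k)    # (B, T, k): routing weights and expert indices
probs = probs / probs.sum(axis=-1, keepdim=True)  # the norm_topk_prob branch renormalizes the top-k
print(probs.numpy(), sel.numpy())                 # per-token weights (summing to 1) and chosen experts
```

The block then runs only the selected experts and sums their outputs weighted by `probs`, which is what the `ffn_*_exps(sel, ...)` calls above express.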
@@ -166,22 +176,20 @@ class TransformerBlock:
def __call__(self, x: Tensor, start_pos: int|UOp):
if not hasattr(self, "cache_kv"):
# TODO: how is the dtype of this determined?
self.cache_kv = Tensor.empty(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)
self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.config.head_dim, self.config.max_context, self.config.rope_theta)
# we pass in the weights implicitly so we unpack the GGUF on the fly
@function(precompile=True, allow_implicit=True)
def _run(x:Tensor, start_pos:int|UOp): return self._feed_forward(self._attention(x, start_pos)).contiguous()
return _run(x, start_pos)

class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
num_experts, num_experts_per_tok, norm_topk_prob) for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.max_context = max_context
def __init__(self, config:TransformerConfig):
self.blk = [TransformerBlock(config) for _ in range(config.num_blocks)]
self.token_embd = nn.Embedding(config.vocab_size, config.dim)
self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.max_context = config.max_context
self._cached_tokens: list[int] = []
# we specialize the JIT for prefill and rollout
self.prefill_jit = TinyJit(self.forward)
@@ -218,15 +226,17 @@ class Transformer:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)

model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
norm_topk_prob=True if arch=='qwen3moe' else False)
config = TransformerConfig(
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
norm_topk_prob=arch == 'qwen3moe')
model = Transformer(config)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
if realize:
@@ -342,11 +352,10 @@ class Handler(HTTPRequestHandler):
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens
# extract tokens, last assistant message is treated as prefill
ids: list[int] = [bos_id] if bos_id is not None else []
for msg in body["messages"]:
for i, msg in enumerate(body["messages"]):
ids += tok.role(msg["role"])
# content can be a str or a list
content = msg["content"]
if isinstance(content, str): ids += tok.encode(content)
elif isinstance(content, list):
@@ -354,8 +363,9 @@ class Handler(HTTPRequestHandler):
if c["type"] == "text": ids += tok.encode(c["text"])
else: raise RuntimeError(f"unhandled type: {c['type']}")
else: raise RuntimeError(f"unknown content type: {type(content)}")
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
ids += tok.end_turn(eos_id)
ids += tok.role("assistant")
else: ids += tok.role("assistant")

# reply
max_tokens = body.get("max_completion_tokens") or body.get("max_tokens")
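A self-contained sketch of the token layout this handler now produces; `MockTok` and its token strings are illustrative stand-ins for the real GGUF tokenizer:

```python
# Hedged sketch of the prefill rule: a trailing assistant message keeps its turn open.
class MockTok:
  def role(self, r): return [f"<{r}>"]
  def encode(self, s): return [s]
  def end_turn(self, eos_id): return ["<end>"]

def build_ids(messages, tok):
  ids = []
  for i, msg in enumerate(messages):
    ids += tok.role(msg["role"]) + tok.encode(msg["content"])
    if msg["role"] == "assistant" and i == len(messages) - 1: break  # prefill: no end_turn
    ids += tok.end_turn(None)
  else: ids += tok.role("assistant")  # no prefill: open a fresh assistant turn
  return ids

# -> ['<user>', 'Hello', '<end>', '<assistant>', 'Sure']  (no <end>, no extra <assistant>)
print(build_ids([{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Sure"}], MockTok()))
```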
@@ -383,8 +393,7 @@ if __name__ == "__main__":
# load the model
raw_model = Tensor.from_url(models[args.model])
model, kv = Transformer.from_gguf(raw_model, args.max_context)
if DEBUG >= 1 or args.benchmark:
print(f"using model {args.model} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
print(f"using model {args.model} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
del raw_model

# TODO: why this is required to free the RAM of the GGUF copy?