llm: support assistant prefill + refactor to TransformerConfig (#15457)

* llm: support assistant prefill

* refactor to ModelConfig

* TransformerConfig

* more
George Hotz
2026-03-25 10:50:48 +08:00
committed by GitHub
parent fd92aec094
commit fe2690399b
5 changed files with 139 additions and 93 deletions

View File

@@ -13,7 +13,7 @@ if __name__ == "__main__":
parser.add_argument("--max_tokens", "-T", type=int, default=4096)
parser.add_argument("--offset", "-O", type=int, default=0)
parser.add_argument("--temperature", "-t", type=float, default=0.0)
parser.add_argument("--no_think", action="store_true", help="append /no_think to disable thinking (for Qwen3)")
parser.add_argument("--no_think", action="store_true", help="disable thinking (prefills empty think block via assistant message)")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
@@ -29,8 +29,10 @@ if __name__ == "__main__":
phrasing = "Given the following question and four candidate answers (A, B, C and D), choose the best answer.\n" +\
f"Question: {question}\n" + '\n'.join([f"{l}. {t}" for l, t in zip(LABEL, choices['text'])]) +\
'\nYour response should end with "The best answer is [the_answer_letter]"' +\
" where the [the_answer_letter] is one of A, B, C or D." + (" /no_think" if args.no_think else "")
resp = client.chat.completions.create(model="test", messages=[{"role": "user", "content": phrasing}],
" where the [the_answer_letter] is one of A, B, C or D."
messages = [{"role": "user", "content": phrasing}]
if args.no_think: messages.append({"role": "assistant", "content": "<think>\n\n</think>\n\n"})
resp = client.chat.completions.create(model="test", messages=messages,
max_tokens=args.max_tokens, temperature=args.temperature)
# normalize answer key (some use 1/2/3/4 instead of A/B/C/D)
correct = answer.as_py().strip()
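
For reference, this is the client-side pattern the change enables: instead of appending "/no_think" to the prompt text, the caller sends a partial assistant message that the server will continue. A minimal sketch, assuming the standard openai client pointed at a locally running llm.py server (base_url, port, and the question are placeholders):

from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")  # placeholder endpoint
messages = [{"role": "user", "content": "What is 2+2? Answer with a single number."}]
# prefill an empty think block so a Qwen3-style model skips its thinking phase
messages.append({"role": "assistant", "content": "<think>\n\n</think>\n\n"})
resp = client.chat.completions.create(model="test", messages=messages, max_tokens=64, temperature=0.0)
print(resp.choices[0].message.content)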

View File

@@ -153,6 +153,49 @@ class TestLLMServer(unittest.TestCase):
self.assertEqual(resp.choices[0].finish_reason, "length")
self.assertEqual(resp.usage.completion_tokens, 2)
def test_assistant_prefill(self):
"""Last assistant message should be treated as prefill (not a completed turn)."""
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
captured_ids = []
orig_generate = self.mock_model.generate.side_effect
def capture_generate(ids, **kwargs):
captured_ids.extend(ids)
return orig_generate(ids, **kwargs)
self.mock_model.generate = Mock(side_effect=capture_generate)
resp = self.client.chat.completions.create(
model="test", messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Sure"}
], stream=False
)
# prefill tokens should be in ids: role("assistant") + encode("Sure") but NO end_turn after it
# and NO extra role("assistant") appended
role_tokens = self.mock_tok.role.call_args_list
# last role() call should be for "assistant" (the prefill message), not an extra one
self.assertEqual(role_tokens[-1], unittest.mock.call("assistant"))
# end_turn should be called one time fewer than role(): the prefill assistant msg doesn't get end_turn
self.assertEqual(self.mock_tok.end_turn.call_count, self.mock_tok.role.call_count - 1)
self.assertIsNotNone(resp.choices[0].message.content)
def test_assistant_prefill_not_last(self):
"""Assistant message that's NOT last should be a normal completed turn."""
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
self.mock_tok.role.reset_mock()
self.mock_tok.end_turn.reset_mock()
self.client.chat.completions.create(
model="test", messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Sure"},
{"role": "user", "content": "Continue"}
], stream=False
)
# all messages get end_turn, plus an extra role("assistant") at the end
# roles: user, assistant, user, assistant(generation prompt) = 4 role calls
# end_turns: user, assistant, user = 3 end_turn calls (one per message)
self.assertEqual(self.mock_tok.end_turn.call_count, 3)
self.assertEqual(self.mock_tok.role.call_count, 4)
def test_models_endpoint(self):
import requests as req
resp = req.get(f"http://127.0.0.1:{self.port}/v1/models")

View File

@@ -1,15 +1,19 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.apps.llm import TransformerBlock, TransformerConfig
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
return TransformerConfig(num_blocks=1, dim=dim, hidden_dim=hidden, n_heads=n_heads, n_kv_heads=n_heads,
norm_eps=1e-5, vocab_size=100, head_dim=dim//n_heads, rope_theta=10000, max_context=16,
num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
class TestMoEFeedForward(unittest.TestCase):
def test_moe_feed_forward(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block = TransformerBlock(_moe_config(dim, hidden, n_heads, num_experts, k))
# set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
@@ -27,12 +31,10 @@ class TestMoEFeedForward(unittest.TestCase):
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
def test_moe_feed_forward_batched(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block = TransformerBlock(_moe_config(dim, hidden, n_heads, num_experts, k))
# same setup as BS=1 test
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
@@ -50,13 +52,11 @@ class TestMoEFeedForward(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
def test_moe_feed_forward_norm_topk_prob(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
block.norm_topk_prob = True
from dataclasses import replace
block = TransformerBlock(replace(_moe_config(dim, hidden, n_heads, num_experts, k), norm_topk_prob=True))
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])

View File

@@ -2,13 +2,15 @@ import unittest
from unittest.mock import patch
from tinygrad import Tensor, UOp
from tinygrad.engine.schedule import schedule_cache
from tinygrad.apps.llm import Transformer, TransformerConfig
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
class TestTransformerGenerate(unittest.TestCase):
def test_kv_cache_reuse(self):
"""Test that generate reuses the KV cache when tokens extend the cached prefix."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
captured_inputs = []
def mock_call(self, tokens, start_pos, temperature):
@@ -35,9 +37,7 @@ class TestTransformerGenerate(unittest.TestCase):
def test_kv_cache_invalidation(self):
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
captured_inputs = []
def mock_call(self, tokens, start_pos, temperature):
@@ -61,9 +61,8 @@ class TestTransformerGenerate(unittest.TestCase):
def test_two_prompts_schedule_cache(self):
"""Third prompt should hit the schedule cache, not miss (first two warm up both jits: prefill + decode)."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
from dataclasses import replace
model = Transformer(replace(TEST_CONFIG, max_context=64))
# first two prompts warm up both jits (prefill + decode)
ids = list(range(1, 6))
@@ -85,10 +84,9 @@ class TestTransformerGenerate(unittest.TestCase):
def test_chunked_prefill(self):
"""When prompt > chunk_size, all chunks should be prefill"""
from tinygrad.apps.llm import Transformer
from tinygrad.uop.ops import resolve
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
from dataclasses import replace
model = Transformer(replace(TEST_CONFIG, max_context=64))
def get_prefill_flags(tokens, chunk_size):
is_prefill = []
@@ -110,9 +108,7 @@ class TestTransformerGenerate(unittest.TestCase):
def test_temperature_zero_is_greedy(self):
"""Temperature 0 (or near 0) should produce deterministic output."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
tokens = list(range(1, 6))
results = [list(zip(range(5), model.generate(list(tokens)))) for _ in range(3)]
# all runs should produce the same tokens
@@ -121,9 +117,7 @@ class TestTransformerGenerate(unittest.TestCase):
def test_temperature_high_produces_variety(self):
"""High temperature should produce different outputs across runs."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
tokens = list(range(1, 6))
runs = set()
for _ in range(5):
@@ -135,9 +129,7 @@ class TestTransformerGenerate(unittest.TestCase):
def test_temperature_passed_to_forward(self):
"""Temperature from generate should be passed through to __call__."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(TEST_CONFIG)
captured_temps = []
def mock_call(self, tokens, start_pos, temperature):
captured_temps.append(float(temperature.item()))
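
Outside the mocks, the same frozen config is all that is needed to exercise generate directly; a minimal sketch of the calling convention (the weights are randomly initialized here, so the sampled tokens themselves are meaningless):

from tinygrad.apps.llm import Transformer, TransformerConfig
config = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                           norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
model = Transformer(config)
# generate yields token ids one at a time; take the first five continuations of the prompt
for i, tok in zip(range(5), model.generate(list(range(1, 6)))):
    print(i, tok)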

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
from dataclasses import dataclass
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.uop.ops import resolve
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
@@ -82,53 +83,62 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
@dataclass(frozen=True)
class TransformerConfig:
num_blocks: int
dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
head_dim: int
rope_theta: float
max_context: int = 0
qk_norm: int = 0
num_experts: int = 0
num_experts_per_tok: int = 0
norm_topk_prob: bool = False
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.rope_theta = rope_theta
self.max_context = max_context
self.qk_norm = qk_norm
def __init__(self, config:TransformerConfig):
self.config = config
# --- attention projections (all linear, bias-free) ------------------
q_proj_out = self.head_dim * n_heads
kv_proj_out = self.head_dim * n_kv_heads
self.attn_q = nn.Linear(dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(q_proj_out, dim, bias=False)
q_proj_out = config.head_dim * config.n_heads
kv_proj_out = config.head_dim * config.n_kv_heads
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(q_proj_out, config.dim, bias=False)
# --- RMSNorms --------------------------------------------------------
self.attn_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)
# --- feed-forward (MoE or dense) -------------------------------------
if num_experts > 0:
self.norm_topk_prob = norm_topk_prob
self.num_experts_per_tok = num_experts_per_tok
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
if config.num_experts > 0:
self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
else:
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
self.ffn_gate = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)
@function
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
B, T, _ = x.shape
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.config.qk_norm == self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
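
To make the refactor concrete, a small sketch of how one frozen config now drives block construction; whether the dense or the MoE feed-forward is built depends only on num_experts (the sizes below are arbitrary toy values):

from tinygrad.apps.llm import TransformerBlock, TransformerConfig
dense_cfg = TransformerConfig(num_blocks=1, dim=8, hidden_dim=16, n_heads=2, n_kv_heads=2,
                              norm_eps=1e-5, vocab_size=100, head_dim=4, rope_theta=10000.0, max_context=16)
moe_cfg = TransformerConfig(num_blocks=1, dim=8, hidden_dim=16, n_heads=2, n_kv_heads=2,
                            norm_eps=1e-5, vocab_size=100, head_dim=4, rope_theta=10000.0, max_context=16,
                            num_experts=4, num_experts_per_tok=2)
dense_blk = TransformerBlock(dense_cfg)  # builds ffn_gate / ffn_up / ffn_down
moe_blk = TransformerBlock(moe_cfg)      # builds the ffn_gate_inp router plus ffn_*_exps expert weights
assert hasattr(dense_blk, "ffn_gate") and hasattr(moe_blk, "ffn_gate_exps")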
@@ -155,8 +165,8 @@ class TransformerBlock:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
if self.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.config.num_experts_per_tok) # (B, T, k) each
if self.config.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
# TODO: remove the need for this contiguous
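
The routing in this hunk (softmax over the router logits, keep the top-k experts, optionally renormalize their probabilities, mix the expert outputs) can be checked in isolation. A standalone numpy sketch of the same math, simplified in that it evaluates every expert rather than only the selected ones:

import numpy as np
def route(router_logits, expert_outputs, k, norm_topk_prob=False):
    # router_logits: (num_experts,), expert_outputs: (num_experts, dim)
    probs = np.exp(router_logits - router_logits.max()); probs /= probs.sum()  # softmax over experts
    sel = np.argsort(probs)[-k:]                  # indices of the k most probable experts
    p = probs[sel]
    if norm_topk_prob: p = p / p.sum()            # renormalize so the selected weights sum to 1
    return (expert_outputs[sel] * p[:, None]).sum(axis=0)  # probability-weighted mix of the k experts
print(route(np.array([0.1, 2.0, 0.3, 1.5]), np.eye(4), k=2, norm_topk_prob=True))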
@@ -166,22 +176,20 @@ class TransformerBlock:
def __call__(self, x: Tensor, start_pos: int|UOp):
if not hasattr(self, "cache_kv"):
# TODO: how is the dtype of this determined?
self.cache_kv = Tensor.empty(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)
self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.config.head_dim, self.config.max_context, self.config.rope_theta)
# we pass in the weights implicitly so we unpack the GGUF on the fly
@function(precompile=True, allow_implicit=True)
def _run(x:Tensor, start_pos:int|UOp): return self._feed_forward(self._attention(x, start_pos)).contiguous()
return _run(x, start_pos)
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
num_experts, num_experts_per_tok, norm_topk_prob) for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.max_context = max_context
def __init__(self, config:TransformerConfig):
self.blk = [TransformerBlock(config) for _ in range(config.num_blocks)]
self.token_embd = nn.Embedding(config.vocab_size, config.dim)
self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.max_context = config.max_context
self._cached_tokens: list[int] = []
# we specialize the JIT for prefill and rollout
self.prefill_jit = TinyJit(self.forward)
@@ -218,15 +226,17 @@ class Transformer:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
norm_topk_prob=True if arch=='qwen3moe' else False)
config = TransformerConfig(
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
norm_topk_prob=arch == 'qwen3moe')
model = Transformer(config)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
if realize:
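
End to end the loading path stays a two-liner, mirroring the __main__ block further down; a sketch assuming a reachable GGUF file (the URL is a placeholder):

from tinygrad import Tensor
from tinygrad.apps.llm import Transformer
raw = Tensor.from_url("https://example.com/model.gguf")  # placeholder URL
model, kv = Transformer.from_gguf(raw, 4096)             # returns the model plus the GGUF kv metadata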
@@ -342,11 +352,10 @@ class Handler(HTTPRequestHandler):
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens
# extract tokens; the last assistant message is treated as prefill
ids: list[int] = [bos_id] if bos_id is not None else []
for msg in body["messages"]:
for i, msg in enumerate(body["messages"]):
ids += tok.role(msg["role"])
# content can be a str or a list
content = msg["content"]
if isinstance(content, str): ids += tok.encode(content)
elif isinstance(content, list):
@@ -354,8 +363,9 @@ class Handler(HTTPRequestHandler):
if c["type"] == "text": ids += tok.encode(c["text"])
else: raise RuntimeError(f"unhandled type: {c['type']}")
else: raise RuntimeError(f"unknown content type: {type(content)}")
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
ids += tok.end_turn(eos_id)
ids += tok.role("assistant")
else: ids += tok.role("assistant")
# reply
max_tokens = body.get("max_completion_tokens") or body.get("max_tokens")
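
Reduced to its essentials, the loop above now emits role plus content for every message, adds end_turn only for completed turns, and opens a fresh assistant turn only when the conversation does not already end in one. A self-contained sketch with a stub tokenizer (token values are made up and the real tok.end_turn takes eos_id; only the structure matters):

def build_ids(messages, role, encode, end_turn, bos_id=None):
    ids = [bos_id] if bos_id is not None else []
    for i, msg in enumerate(messages):
        ids += role(msg["role"]) + encode(msg["content"])
        if msg["role"] == "assistant" and i == len(messages) - 1:
            return ids                  # prefill: leave the turn open so the model continues it
        ids += end_turn()
    return ids + role("assistant")      # no prefill: open a new assistant turn for generation

role, encode, end_turn = (lambda r: [f"<{r}>"]), (lambda s: list(s)), (lambda: ["<end>"])
print(build_ids([{"role": "user", "content": "hi"}, {"role": "assistant", "content": "Sure"}], role, encode, end_turn))
# -> ['<user>', 'h', 'i', '<end>', '<assistant>', 'S', 'u', 'r', 'e']  (no <end>, no extra <assistant>)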
@@ -383,8 +393,7 @@ if __name__ == "__main__":
# load the model
raw_model = Tensor.from_url(models[args.model])
model, kv = Transformer.from_gguf(raw_model, args.max_context)
if DEBUG >= 1 or args.benchmark:
print(f"using model {args.model} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
print(f"using model {args.model} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
del raw_model
# TODO: why is this required to free the RAM of the GGUF copy?