From 8a82b26522cc00d82a8902240b895326809215ed Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 5 Mar 2026 12:13:28 +0800
Subject: [PATCH] llm: print the prefill cache size (#15146)

* print the llm prefill cache size

* mock that too
---
 test/null/test_llm_server.py |  1 +
 tinygrad/apps/llm.py         | 11 ++++++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/test/null/test_llm_server.py b/test/null/test_llm_server.py
index 942baea061..8ddd0df556 100644
--- a/test/null/test_llm_server.py
+++ b/test/null/test_llm_server.py
@@ -14,6 +14,7 @@ class TestLLMServer(unittest.TestCase):
     cls.mock_model = Mock()
     cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
+    cls.mock_model.get_start_pos = Mock(return_value=0)
     cls.bos_id = 1
     cls.eos_id = 999
 
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 6428da0ea5..d82f2a3a97 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -225,13 +225,16 @@ class Transformer:
     Tensor.realize(*params)
     return model, kv
 
+  def get_start_pos(self, tokens:list[int]):
+    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+
   def generate(self, tokens:list[int]):
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
     v_toks = UOp.variable("toks", 1, self.max_context)
     # assign all input tokens once, then slice from start_pos for the model call
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
     # recompute start_pos from what's currently valid in the kv cache
-    start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+    start_pos = self.get_start_pos(tokens)
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
       t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
@@ -295,13 +298,15 @@ class Handler(HTTPRequestHandler):
   def log_request(self, code='-', size='-'): pass
   def do_GET(self): self.send_data(CHAT_HTML, content_type="text/html")
   def run_model(self, ids:list[int], model_name:str, include_usage=False):
-    stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
+    cache_start_pos = model.get_start_pos(ids)
+    stderr_log(f"{self.path} {colored('--', 'BLACK')} "
+               f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
     tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
     out: list[int] = []
     st = time.perf_counter()
     for next_id in model.generate(ids):
-      if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
+      if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
       if next_id == eos_id: break
       out.append(next_id)
       yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
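
Note: the core of this change is the new Transformer.get_start_pos, which counts how many leading tokens of the incoming prompt already match Transformer._cached_tokens, so the server log can report the kv-cache hit (the green "in:" count) separately from the tokens that still need prefill (the "+" count), and compute prefill tok/s over only the latter. Below is a minimal standalone sketch of that prefix-matching logic; the cached_tokens and ids values are hypothetical and not from this patch.

# Standalone sketch (not part of the patch) of the prefix-cache logic in
# Transformer.get_start_pos; token values below are made up for illustration.
import itertools

def get_start_pos(tokens: list[int], cached_tokens: list[int]) -> int:
  # length of the longest common prefix between the new prompt and the kv cache contents
  return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, cached_tokens)))

cached_tokens = [1, 15, 27, 99]   # tokens already sitting in the kv cache (hypothetical)
ids = [1, 15, 27, 42, 7]          # incoming prompt tokens (hypothetical)
cache_start_pos = get_start_pos(ids, cached_tokens)
print(cache_start_pos, len(ids) - cache_start_pos)  # 3 cached, 2 left to prefill -> logged as "in:    3 +    2"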