llm: print the prefill cache size (#15146)

* print the llm prefill cache size

* mock that too
George Hotz
2026-03-05 12:13:28 +08:00
committed by GitHub
parent b5370fd52d
commit 8a82b26522
2 changed files with 9 additions and 3 deletions


@@ -14,6 +14,7 @@ class TestLLMServer(unittest.TestCase):
     cls.mock_model = Mock()
     cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
+    cls.mock_model.get_start_pos = Mock(return_value=0)
     cls.bos_id = 1
     cls.eos_id = 999
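Because run_model now asks the model for its cached start position, the mocked model in the test needs a get_start_pos too. A minimal hedged sketch of what the mock now provides, with the surrounding TestLLMServer/server plumbing omitted and made-up token ids:

from unittest.mock import Mock

mock_model = Mock()
mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
mock_model.get_start_pos = Mock(return_value=0)  # new: the handler queries the cached position before logging

ids = [1, 5, 7]                                            # made-up prompt ids
assert mock_model.get_start_pos(ids) == 0                  # pretend nothing is cached
assert list(mock_model.generate(ids)) == [300, 301, 999]   # the fixed completion stream the tests expect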


@@ -225,13 +225,16 @@ class Transformer:
     Tensor.realize(*params)
     return model, kv
+  def get_start_pos(self, tokens:list[int]):
+    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
   def generate(self, tokens:list[int]):
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
     v_toks = UOp.variable("toks", 1, self.max_context)
     # assign all input tokens once, then slice from start_pos for the model call
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
     # recompute start_pos from what's currently valid in the kv cache
-    start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+    start_pos = self.get_start_pos(tokens)
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
       t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
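The cache start position is just the length of the common prefix between the request's token ids and the ids already sitting in the kv cache. A standalone sketch of the same takewhile logic, using made-up token lists instead of self._cached_tokens:

import itertools

def get_start_pos(tokens: list[int], cached_tokens: list[int]) -> int:
  # count how many leading tokens match what the kv cache already holds
  return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, cached_tokens)))

cached  = [1, 15, 22, 9, 4]     # ids processed on an earlier request
request = [1, 15, 22, 7, 8, 3]  # shares a 3-token prefix with the cache
start_pos = get_start_pos(request, cached)
print(f"in:{start_pos:5d} +{len(request)-start_pos:5d}")  # prints "in:    3 +    3"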
@@ -295,13 +298,15 @@ class Handler(HTTPRequestHandler):
   def log_request(self, code='-', size='-'): pass
   def do_GET(self): self.send_data(CHAT_HTML, content_type="text/html")
   def run_model(self, ids:list[int], model_name:str, include_usage=False):
-    stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
+    cache_start_pos = model.get_start_pos(ids)
+    stderr_log(f"{self.path} {colored('--', 'BLACK')} "
+               f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
     tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
     out: list[int] = []
     st = time.perf_counter()
     for next_id in model.generate(ids):
-      if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
+      if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
       if next_id == eos_id: break
       out.append(next_id)
       yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}