mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llm: print the prefill cache size (#15146)
* print the llm prefill cache size
* mock that too
This commit is contained in:
@@ -14,6 +14,7 @@ class TestLLMServer(unittest.TestCase):
|
||||
|
||||
cls.mock_model = Mock()
|
||||
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
|
||||
cls.mock_model.get_start_pos = Mock(return_value=0)
|
||||
|
||||
cls.bos_id = 1
|
||||
cls.eos_id = 999
|
||||
|
||||
@@ -225,13 +225,16 @@ class Transformer:
|
||||
Tensor.realize(*params)
|
||||
return model, kv
|
||||
|
||||
def get_start_pos(self, tokens:list[int]):
    """Return the length of the common prefix between `tokens` and the cached token history.

    This is the number of leading tokens already present in the kv cache,
    i.e. the position generation can resume from without recomputing.
    """
    # walk both sequences in lockstep and stop at the first mismatch
    matched = 0
    for new_tok, cached_tok in zip(tokens, self._cached_tokens):
        if new_tok != cached_tok: break
        matched += 1
    return matched
|
||||
|
||||
def generate(self, tokens:list[int]):
|
||||
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
|
||||
v_toks = UOp.variable("toks", 1, self.max_context)
|
||||
# assign all input tokens once, then slice from start_pos for the model call
|
||||
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
|
||||
# recompute start_pos from what's currently valid in the kv cache
|
||||
start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
|
||||
start_pos = self.get_start_pos(tokens)
|
||||
while len(tokens) < self.max_context:
|
||||
sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
|
||||
t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
|
||||
@@ -295,13 +298,15 @@ class Handler(HTTPRequestHandler):
|
||||
def log_request(self, code='-', size='-'):
    """Deliberately silence the default per-request access logging."""
|
||||
# GET handler: ignores the request path and always serves the embedded
# chat page (CHAT_HTML) as text/html via send_data.
def do_GET(self): self.send_data(CHAT_HTML, content_type="text/html")
|
||||
def run_model(self, ids:list[int], model_name:str, include_usage=False):
|
||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
|
||||
cache_start_pos = model.get_start_pos(ids)
|
||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
||||
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
||||
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
|
||||
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
|
||||
out: list[int] = []
|
||||
st = time.perf_counter()
|
||||
for next_id in model.generate(ids):
|
||||
if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||
if next_id == eos_id: break
|
||||
out.append(next_id)
|
||||
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
|
||||
|
||||
Reference in New Issue
Block a user