diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py
index c43fa86f77..350a414ced 100644
--- a/test/unit/test_llm_server.py
+++ b/test/unit/test_llm_server.py
@@ -60,25 +60,28 @@ class TestTransformerGenerate(unittest.TestCase):
     self.assertEqual(captured_inputs[0][1], 0)
 
   def test_two_prompts_schedule_cache(self):
-    """Second prompt prefill should hit the schedule cache, not miss."""
+    """Third prompt should hit the schedule cache, not miss (first two warm up both jits: prefill + decode)."""
     from tinygrad.apps.llm import Transformer
     model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2, norm_eps=1e-5, vocab_size=100,
                         head_dim=32, rope_theta=10000.0, max_context=64)
 
-    # first prompt: prefill + a few decode steps
+    # first two prompts warm up both jits (prefill + decode)
    ids = list(range(1, 6))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
-    cache_size_after_first = len(schedule_cache)
 
-    # second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
     ids += list(range(10, 15))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
+    cache_size_after_warmup = len(schedule_cache)
 
-    # the second prompt should reuse the same schedule cache entries, not create new ones
-    self.assertEqual(cache_size_after_first, len(schedule_cache),
-                     f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
+    # third prompt should reuse the same schedule cache entries, not create new ones
+    ids += list(range(20, 25))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+
+    self.assertEqual(cache_size_after_warmup, len(schedule_cache),
+                     f"third prompt added {len(schedule_cache) - cache_size_after_warmup} new schedule cache entries (expected 0)")
 
 if __name__ == '__main__':
   unittest.main()
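
Note on the test change above: as TinyJit currently behaves, the wrapped function runs eagerly on its first call, captures kernels on the second, and replays the capture from the third call on. Since the model now holds two jits (prefill_jit and rollout_jit), the first two prompts exercise each jit twice, so only the third prompt can be expected to add zero schedule-cache entries. A minimal standalone sketch of that warm-up pattern, using a hypothetical step function that is not part of this PR:

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def step(x: Tensor) -> Tensor:
      # any fixed-shape computation; the jit can only replay when input shapes repeat
      return (x * 2 + 1).realize()

    for i in range(4):
      print(i, step(Tensor([[float(i)]])).item())  # calls 1-2 warm the jit, calls 3+ replay the capture
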
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 06e170051f..cc60282590 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
+from tinygrad.uop.ops import resolve
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
 
@@ -144,7 +145,7 @@ class TransformerBlock:
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1)  # back to (B,T,D)
     attn = self.attn_output(attn)
 
@@ -179,7 +180,9 @@ class Transformer:
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self._cached_tokens: list[int] = []
-    self.forward_jit = TinyJit(self.forward)
+    # we specialize the JIT for prefill and rollout
+    self.prefill_jit = TinyJit(self.forward)
+    self.rollout_jit = TinyJit(self.forward)
 
   def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
     x = self.token_embd(tokens)  # (B, T, D)
@@ -187,7 +190,8 @@ class Transformer:
     # TODO: add temperature
     return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
 
-  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
+  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
+    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens, start_pos)
 
   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
@@ -226,7 +230,7 @@ class Transformer:
     return model, kv
 
   def get_start_pos(self, tokens:list[int]):
-    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
 
   def generate(self, tokens:list[int], chunk_size:int=32):
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
@@ -235,15 +239,13 @@ class Transformer:
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
     # recompute start_pos from what's currently valid in the kv cache
     start_pos = self.get_start_pos(tokens)
+    out = None
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
-      out = self(t[:, sp:sp+nt], sp)
+      out = self(t[:, sp:sp+nt] if out is None else out, sp).realize()
       start_pos += nt.val
       # chunked prefill: keep processing until all prompt tokens are consumed
-      if start_pos < len(tokens):
-        out.realize()
-        continue
-      t[:, sp+nt:sp+nt+1] = out
+      if start_pos < len(tokens): out = None; continue  # reset out so the next prefill chunk reads from the prompt, not the mid-prompt prediction
       tokens.append(int(out.item()))
       self._cached_tokens = tokens[:]
       yield tokens[-1]
@@ -275,7 +277,7 @@
[CHAT_HTML hunk: one-line change inside the embedded chat-page markup; the HTML itself was stripped in extraction]
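
Note on the dispatch in __call__: resolve collapses the shape comparison to a plain bool whether tokens.shape[1] is a concrete int or a symbolic variable, which is what lets a single line route between the two jits. A rough standalone illustration of the same check (the route helper and shapes here are made up for this sketch, not part of the PR):

    from tinygrad import Tensor, UOp
    from tinygrad.uop.ops import resolve

    def route(tokens: Tensor) -> str:
      # the same check __call__ uses: does the sequence length differ from 1?
      return "prefill_jit" if resolve(tokens.shape[1] != 1) else "rollout_jit"

    print(route(Tensor.zeros(1, 1)))           # rollout_jit: decode feeds a single token
    print(route(Tensor.zeros(1, 8)))           # prefill_jit: a prompt chunk
    n = UOp.variable("n_toks", 1, 32).bind(8)
    print(route(Tensor.zeros(1, 32)[:, 0:n]))  # prefill_jit: symbolic chunk length resolves truthy

Feeding out straight back as the next decode input (rather than the old t[:, sp+nt:sp+nt+1] = out writeback) keeps the sampled token as the jit input directly; the out = None reset in the prefill path above ensures later prompt chunks still come from the padded prompt buffer.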