LLM speedup with two jits, prefill/rollout (#15153)

* START_TIME

* print cleanup

* fix tests
George Hotz
2026-03-05 16:21:09 +08:00
committed by GitHub
parent be23772d43
commit e97922a57c
4 changed files with 31 additions and 23 deletions


@@ -60,25 +60,28 @@ class TestTransformerGenerate(unittest.TestCase):
     self.assertEqual(captured_inputs[0][1], 0)

   def test_two_prompts_schedule_cache(self):
-    """Second prompt prefill should hit the schedule cache, not miss."""
+    """Third prompt should hit the schedule cache, not miss (first two warm up both jits: prefill + decode)."""
     from tinygrad.apps.llm import Transformer
     model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                         norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
-    # first prompt: prefill + a few decode steps
+    # first two prompts warm up both jits (prefill + decode)
     ids = list(range(1, 6))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
-    cache_size_after_first = len(schedule_cache)
-    # second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
     ids += list(range(10, 15))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
-    # the second prompt should reuse the same schedule cache entries, not create new ones
-    self.assertEqual(cache_size_after_first, len(schedule_cache),
-                     f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
+    cache_size_after_warmup = len(schedule_cache)
+    # third prompt should reuse the same schedule cache entries, not create new ones
+    ids += list(range(20, 25))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+    self.assertEqual(cache_size_after_warmup, len(schedule_cache),
+                     f"third prompt added {len(schedule_cache) - cache_size_after_warmup} new schedule cache entries (expected 0)")

 if __name__ == '__main__':
   unittest.main()
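
Why the test now needs a third prompt: a TinyJit runs eagerly on its first call and only captures on its second, and the prefill jit is invoked just once per prompt, so two full prompts are needed before both jits replay from cache. A minimal sketch of that capture behavior (toy function and shapes, not the model's forward):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor: return (x * 2).realize()

step(Tensor.ones(4))  # call 1: runs eagerly, nothing captured yet
step(Tensor.ones(4))  # call 2: captures the kernels
step(Tensor.ones(4))  # call 3+: replays the capture, no new schedule entries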


@@ -1,6 +1,7 @@
 from __future__ import annotations
 import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
+from tinygrad.uop.ops import resolve
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,7 +145,7 @@ class TransformerBlock:
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
     attn = self.attn_output(attn)
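
The else None branch exists because for a single-token decode step the mask is a no-op: with T == 1 the one query row may attend to all start_pos+1 cached positions, so triu(start_pos+1) keeps nothing and the tensor is all zeros. resolve is needed since T may be a symbolic UOp rather than an int. A quick check with concrete values (illustrative sizes):

from tinygrad import Tensor

start_pos, T = 4, 3
# causal_lower_right: row i of the T new queries attends to positions 0..start_pos+i
print(Tensor.full((1, 1, T, start_pos+T), float("-inf")).triu(start_pos+1).numpy())
# with T == 1 every cached position is visible, so the mask would be all zeros
print(Tensor.full((1, 1, 1, start_pos+1), float("-inf")).triu(start_pos+1).numpy())
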
@@ -179,7 +180,9 @@ class Transformer:
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self._cached_tokens: list[int] = []
-    self.forward_jit = TinyJit(self.forward)
+    # we specialize the JIT for prefill and rollout
+    self.prefill_jit = TinyJit(self.forward)
+    self.rollout_jit = TinyJit(self.forward)

   def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
     x = self.token_embd(tokens) # (B, T, D)
@@ -187,7 +190,8 @@ class Transformer:
     # TODO: add temperature
     return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)

-  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
+  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
+    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens, start_pos)

   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
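
Two TinyJit instances over the same forward give each input shape its own capture: prompts (T > 1) always hit prefill_jit and single-token decode steps always hit rollout_jit, so neither capture is disturbed by alternating shapes. A sketch of the dispatch pattern in isolation (toy forward; assumes resolve collapses a concrete or symbolic comparison to a bool):

from tinygrad import Tensor, TinyJit
from tinygrad.uop.ops import resolve

class TwoJit:
  def __init__(self):
    self.prefill_jit = TinyJit(self.forward)  # multi-token (prompt) specialization
    self.rollout_jit = TinyJit(self.forward)  # single-token (decode) specialization
  def forward(self, x:Tensor) -> Tensor: return (x * 2).realize()
  def __call__(self, x:Tensor) -> Tensor:
    return (self.prefill_jit if resolve(x.shape[1] != 1) else self.rollout_jit)(x)

m = TwoJit()
m(Tensor.ones(1, 8))  # dispatched to prefill_jit
m(Tensor.ones(1, 1))  # dispatched to rollout_jit
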
@@ -226,7 +230,7 @@ class Transformer:
     return model, kv

   def get_start_pos(self, tokens:list[int]):
-    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))

   def generate(self, tokens:list[int], chunk_size:int=32):
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
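
The tokens[:-1] clip fixes a boundary case: if a prompt exactly equals the cached tokens (e.g. multi-turn chat re-sending the whole conversation), matching all of tokens would give start_pos == len(tokens) and leave nothing for the model to run on. Clipping guarantees the final token is always recomputed. A worked example (illustrative token ids):

import itertools

def get_start_pos(tokens, cached):
  return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], cached)))

cached = [5, 7, 9]
print(get_start_pos([5, 7, 9], cached))     # 2, not 3: the final token still runs through the model
print(get_start_pos([5, 7, 9, 2], cached))  # 3: only the newly added token is prefilled
print(get_start_pos([5, 8, 9, 2], cached))  # 1: divergence ends the reusable prefix
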
@@ -235,15 +239,13 @@ class Transformer:
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
     # recompute start_pos from what's currently valid in the kv cache
     start_pos = self.get_start_pos(tokens)
+    out = None
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
-      out = self(t[:, sp:sp+nt], sp)
+      out = self(t[:, sp:sp+nt] if out is None else out, sp).realize()
       start_pos += nt.val
       # chunked prefill: keep processing until all prompt tokens are consumed
-      if start_pos < len(tokens):
-        out.realize()
-        continue
+      t[:, sp+nt:sp+nt+1] = out
+      if start_pos < len(tokens): continue
       tokens.append(int(out.item()))
       self._cached_tokens = tokens[:]
       yield tokens[-1]
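
As I read the diff, the rollout speedup comes from feeding the previous step's output tensor straight back in: only the first call (out is None) slices the prompt out of t, every later call passes out itself, so the decode jit sees a stable (1, 1) input, and t[:, sp+nt:sp+nt+1] = out writes the sampled token back into the padded buffer. A plain-Python sketch of the control flow (no symbolic variables; the write-back and the chunked-prefill reset are simplified here):

import itertools

def generate(tokens, max_context=16, chunk_size=4):
  model = lambda inp, start_pos: (sum(inp) + start_pos) % 97  # stand-in for the jitted forward
  t = tokens + [0] * (max_context - len(tokens))              # padded token buffer, like Tensor t
  start_pos, out = 0, None
  while len(tokens) < max_context:
    nt = min(chunk_size, len(tokens) - start_pos)
    out = model(t[start_pos:start_pos+nt] if out is None else [out], start_pos)
    start_pos += nt
    if start_pos < len(tokens): out = None; continue  # chunked prefill: next chunk from the buffer
    t[start_pos] = out  # record the sampled token
    tokens.append(out)
    yield out

print(list(itertools.islice(generate([3, 1, 4, 1, 5, 9]), 4)))
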
@@ -275,7 +277,7 @@ CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
   background: #2f2f2f; color: inherit; font: inherit;
   border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
 </style></head><body><div id="chat"></div>
-<textarea id="input" rows="1" placeholder="Ask anything"></textarea>
+<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
 <script>
 input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
 const msgs = [];
@@ -318,7 +320,7 @@ class Handler(HTTPRequestHandler):
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
stderr_log(f"gen:{len(out)/(time.perf_counter()-pt):4.0f} tok/s {colored('--', 'BLACK')} out:{len(out):5d}\n")
def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
@@ -387,8 +389,9 @@ if __name__ == "__main__":
   # start server
   if args.serve:
-    # warmup: run 2 tokens through the model to capture the JIT before serving
-    with Context(DEBUG=max(DEBUG.value, 1)): list(zip(range(2), model.generate([0])))
+    # warmup: run 2 tokens through the model twice to capture the JIT before serving
+    with Context(DEBUG=max(DEBUG.value, 1)):
+      for _ in range(2): list(zip(range(2), model.generate([0])))
     TCPServerWithReuse(('', args.serve), Handler).serve_forever()

   # interactive chat
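
The warmup loops twice for the same reason as the test change above: each TinyJit captures on its second call, and one pass through generate only invokes the prefill jit once. The list(zip(range(2), ...)) idiom is just take-two from the generator:

import itertools
gen = (i * i for i in itertools.count())  # stand-in for model.generate([0])
print(list(zip(range(2), gen)))           # [(0, 0), (1, 1)]: two items, then stop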


@@ -1,5 +1,7 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
+import time
+START_TIME = time.perf_counter()
+import os, functools, platform, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
 from collections import defaultdict
 import subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
 from dataclasses import dataclass, field
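
Recording START_TIME before the heavy imports lets any later module report seconds since the tinygrad import began, which is what viz/serve.py does below for server startup. The pattern in isolation (uptime is a hypothetical helper, not part of tinygrad):

import time
START_TIME = time.perf_counter()
import json, sqlite3  # heavier imports happen after the clock starts (illustrative)

def uptime() -> float:  # hypothetical helper for this sketch
  return time.perf_counter() - START_TIME

print(f"ready at {uptime():.2f} s")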


@@ -7,7 +7,7 @@ from urllib.parse import parse_qs, urlparse
 from http.server import BaseHTTPRequestHandler
 from typing import Any, TypedDict, TypeVar, Generator, Callable
 from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
-from tinygrad.helpers import printable, Context
+from tinygrad.helpers import printable, Context, START_TIME
 from tinygrad.renderer.amd.dsl import Inst
 from tinygrad.renderer.amd import detect_format
@@ -15,7 +15,7 @@ from tinygrad.renderer.amd import detect_format
 class TCPServerWithReuse(socketserver.TCPServer):
   allow_reuse_address = True
   def __init__(self, server_address, RequestHandlerClass):
-    print(f"*** started server on http://127.0.0.1:{server_address[1]}")
+    print(f"*** started server on http://127.0.0.1:{server_address[1]} at {time.perf_counter()-START_TIME:.2f} s")
     super().__init__(server_address, RequestHandlerClass)

 class HTTPRequestHandler(BaseHTTPRequestHandler):