fast tinygrad.apps.llm (#13685)

* llm: add --benchmark support

* fix speed

* debug logging

* fix test attention
This commit is contained in:
George Hotz
2025-12-14 21:05:21 -05:00
committed by GitHub
parent 6cad622f59
commit 572ca80046
3 changed files with 38 additions and 19 deletions

View File

@@ -1,8 +1,14 @@
import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
#from tinygrad.engine.realize import run_schedule
def apply_rope(x:Tensor, start_pos:int):
B, H, T, Hd = x.shape
precompute_freqs_cis.cache_clear()
freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
return apply_rope_new(x, freqs_cis)
# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
def test_half_qkv_buffers(self):
@@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
prune_size = len(rope_prune.captured.jit_cache)
self.assertGreater(noprune_size, prune_size)
self.assertGreaterEqual(noprune_size, 3)
self.assertGreaterEqual(noprune_size, 2)
self.assertEqual(prune_size, 1)
if __name__ == '__main__':

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG, Timing, GlobalCounters
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@@ -52,15 +52,18 @@ class SimpleTokenizer:
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
B, H, T, Hd = x.shape
assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
half = Hd // 2
t_start_pos = start_pos if isinstance(start_pos, int) else Tensor(start_pos)
angles = (Tensor.arange(T, dtype="float32") + t_start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
# contiguous here allows RoPE to be pruned in the JIT
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
x_pairs = x.reshape(B, H, T, half, 2)
x_pairs = x.reshape(B, H, T, Hd//2, 2)
cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
@@ -96,8 +99,10 @@ class TransformerBlock:
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
q = apply_rope(q, start_pos)
k = apply_rope(k, start_pos)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
# TODO: remove these kv cache realizes
if not hasattr(self, "cache_kv"):
@@ -115,7 +120,8 @@ class TransformerBlock:
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
# TODO: remove the need for this contiguous
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
return h + self.ffn_down(gated)
def __call__(self, x: Tensor, start_pos: int|UOp):
@@ -185,24 +191,27 @@ models = {
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
class Handler(HTTPRequestHandler):
def log_request(self, code='-', size='-'): pass
def run_model(self, ids:list[int], model_name:str, include_usage=False):
stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
out = []
for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
out: list[int] = []
st = time.perf_counter()
for next_id in model.generate(ids):
if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
if next_id == eos_id: break
out.append(next_id)
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1:
print(self.path)
print(json.dumps(body, indent=2))
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens
ids = [bos_id]

View File

@@ -149,6 +149,10 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
def stderr_log(msg):
sys.stderr.write(msg)
sys.stderr.flush()
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self):