diff --git a/test/unit/test_attention.py b/test/unit/test_attention.py
index e47b74fbe4..99cf1c13ce 100644
--- a/test/unit/test_attention.py
+++ b/test/unit/test_attention.py
@@ -1,8 +1,14 @@
 import unittest
 from tinygrad import Tensor, dtypes, TinyJit, UOp
-from tinygrad.apps.llm import apply_rope
+from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
 #from tinygrad.engine.realize import run_schedule
 
+def apply_rope(x:Tensor, start_pos:int):
+  B, H, T, Hd = x.shape
+  precompute_freqs_cis.cache_clear()
+  freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
+  return apply_rope_new(x, freqs_cis)
+
 # TODO: test_scheduler, but just in uint
 class TestAttention(unittest.TestCase):
   def test_half_qkv_buffers(self):
@@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
     prune_size = len(rope_prune.captured.jit_cache)
 
     self.assertGreater(noprune_size, prune_size)
-    self.assertGreaterEqual(noprune_size, 3)
+    self.assertGreaterEqual(noprune_size, 2)
     self.assertEqual(prune_size, 1)
 
 if __name__ == '__main__':
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 1cb78c1dfa..071b1e3ae6 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid, time
+import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv
-from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG, Timing, GlobalCounters
+from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
 
 class SimpleTokenizer:
   def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@@ -52,15 +52,18 @@ class SimpleTokenizer:
   def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
   def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
 
-def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
+@functools.cache
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
+
+def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
   B, H, T, Hd = x.shape
   assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
-  half = Hd // 2
-  t_start_pos = start_pos if isinstance(start_pos, int) else Tensor(start_pos)
-  angles = (Tensor.arange(T, dtype="float32") + t_start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
-  # contiguous here allows RoPE to be pruned in the JIT
-  cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
-  x_pairs = x.reshape(B, H, T, half, 2)
+  x_pairs = x.reshape(B, H, T, Hd//2, 2)
+  cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
+  sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
   return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
                       x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
 
@@ -96,8 +99,10 @@ class TransformerBlock:
     k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
 
-    q = apply_rope(q, start_pos)
-    k = apply_rope(k, start_pos)
+    # TODO: make UOp have SupportsIndex
+    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
+    q = apply_rope(q, freqs_cis)
+    k = apply_rope(k, freqs_cis)
 
     # TODO: remove these kv cache realizes
     if not hasattr(self, "cache_kv"):
@@ -115,7 +120,8 @@ class TransformerBlock:
 
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
-    gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
+    # TODO: remove the need for this contiguous
+    gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
     return h + self.ffn_down(gated)
 
   def __call__(self, x: Tensor, start_pos: int|UOp):
@@ -185,24 +191,27 @@ models = {
 
 # OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
 class Handler(HTTPRequestHandler):
+  def log_request(self, code='-', size='-'): pass
   def run_model(self, ids:list[int], model_name:str, include_usage=False):
+    stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
     tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
-    out = []
-    for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
+    out: list[int] = []
+    st = time.perf_counter()
+    for next_id in model.generate(ids):
+      if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
       if next_id == eos_id: break
       out.append(next_id)
       yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
     yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
     if include_usage: yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
+    stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
 
   def do_POST(self):
     raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
     body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
-    if DEBUG >= 1:
-      print(self.path)
-      print(json.dumps(body, indent=2))
+    if DEBUG >= 1: print(json.dumps(body, indent=2))
 
     if self.path == "/v1/chat/completions":
       # extract tokens
       ids = [bos_id]
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 905fd08144..324764b296 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -149,6 +149,10 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
 
 def temp(x:str, append_user:bool=False) -> str: return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
 
+def stderr_log(msg):
+  sys.stderr.write(msg)
+  sys.stderr.flush()
+
 class Context(contextlib.ContextDecorator):
   def __init__(self, **kwargs): self.kwargs = kwargs
   def __enter__(self):
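
For illustration, a minimal standalone sketch of the RoPE flow this patch introduces: precompute_freqs_cis builds the whole (end, dim//2, 2) cos/sin table once per (dim, end) pair (functools.cache memoizes it, and the contiguous realizes it once; the comment removed from the old apply_rope notes that this realization is what lets RoPE be pruned in the JIT), and each caller slices T rows starting at start_pos before handing the table to apply_rope. The two functions mirror the patch; B, H, T, Hd, start_pos, and max_context below are assumed illustration values, not anything fixed by the diff.

import functools
from tinygrad import Tensor

@functools.cache
def precompute_freqs_cis(dim:int, end:int, theta:float=10000.0) -> Tensor:
  # per-pair inverse frequencies, then an (end, dim//2) angle table: position * frequency
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  # interleave cos/sin in a trailing dim of 2; contiguous realizes the table once
  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()

def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  # rotate adjacent (even, odd) channel pairs of x by the precomputed angles
  B, H, T, Hd = x.shape
  x_pairs = x.reshape(B, H, T, Hd//2, 2)
  cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
  sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
  return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
                      x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)

# one decode step at an assumed position: slice T rows from the cached table
B, H, T, Hd, start_pos, max_context = 1, 8, 1, 64, 10, 4096
q = Tensor.randn(B, H, T, Hd)
freqs_cis = precompute_freqs_cis(Hd, max_context)[start_pos:start_pos+T]
assert apply_rope(q, freqs_cis).shape == (B, H, T, Hd)

The test shim at the top of the patch wraps these same two calls behind the old apply_rope(x, start_pos) signature, calling cache_clear() first so each test run measures JIT kernel counts from a cold cache.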