# tinygrad/test/unit/test_attention.py
# from commit 572ca80046 "fast tinygrad.apps.llm (#13685)" by George Hotz, 2025-12-14

import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
#from tinygrad.engine.realize import run_schedule
def apply_rope(x:Tensor, start_pos:int):
  B, H, T, Hd = x.shape
  # drop the cached freqs so every call recomputes them
  precompute_freqs_cis.cache_clear()
  # keep only the rows for positions [start_pos, start_pos+T)
  freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
  return apply_rope_new(x, freqs_cis)
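# for reference, the rotation rope applies to each (even, odd) channel pair at
# position p and pair index i looks roughly like this (illustrative sketch, not
# tinygrad's implementation; base is commonly 10000):
#   theta   = p / (base ** (2*i / Hd))
#   x_even' = x_even*cos(theta) - x_odd*sin(theta)
#   x_odd'  = x_even*sin(theta) + x_odd*cos(theta)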
# TODO: test_scheduler, but just in unit
class TestAttention(unittest.TestCase):
  def test_half_qkv_buffers(self):
    BS, seqlen, dim = 10, 4, 100
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention currently lowers to 4 kernels
    self.assertEqual(len(sched), 4)
    # softmax_inputs = sched[1:4]
    # for i,si in enumerate(softmax_inputs):
    #   assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
  def test_apply_rope(self):
    x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
    result = apply_rope(x, 0)
    # rope is a rotation: shape and dtype are preserved
    self.assertEqual(result.shape, x.shape)
    self.assertEqual(result.dtype, x.dtype)
    # different start positions must rotate by different angles
    self.assertGreater((result - apply_rope(x, 5)).abs().max().item(), 1e-6)
    # an odd head dim can't be split into (even, odd) pairs, so it should assert
    with self.assertRaises(AssertionError): apply_rope(Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32), 0)
  def test_apply_rope_jit_prune(self):
    def rope_fn(x_in, pos): return apply_rope(x_in, pos)
    rope_noprune = TinyJit(rope_fn)
    rope_prune = TinyJit(rope_fn, prune=True)
    v_pos = UOp.variable("start_pos", 0, 100)
    # run enough times for both jits to capture
    for _ in range(3):
      rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
      rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
    noprune_size = len(rope_noprune.captured.jit_cache)
    prune_size = len(rope_prune.captured.jit_cache)
    # the pruned jit should capture strictly fewer kernels
    self.assertGreater(noprune_size, prune_size)
    self.assertGreaterEqual(noprune_size, 2)
    self.assertEqual(prune_size, 1)
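# to run this file directly (assuming you are at the tinygrad repo root):
#   python test/unit/test_attention.py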
if __name__ == '__main__':
  unittest.main()