fix test attention

commit 63c16143fc
parent d4385d62d3
Author: George Hotz
Date: 2025-12-14 20:45:19 -05:00


@@ -1,8 +1,14 @@
 import unittest
 from tinygrad import Tensor, dtypes, TinyJit, UOp
-from tinygrad.apps.llm import apply_rope
+from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
 #from tinygrad.engine.realize import run_schedule
 
+def apply_rope(x:Tensor, start_pos:int):
+  B, H, T, Hd = x.shape
+  precompute_freqs_cis.cache_clear()
+  freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
+  return apply_rope_new(x, freqs_cis)
+
 # TODO: test_scheduler, but just in uint
 class TestAttention(unittest.TestCase):
   def test_half_qkv_buffers(self):
@@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
     prune_size = len(rope_prune.captured.jit_cache)
     self.assertGreater(noprune_size, prune_size)
-    self.assertGreaterEqual(noprune_size, 3)
+    self.assertGreaterEqual(noprune_size, 2)
     self.assertEqual(prune_size, 1)
 
 if __name__ == '__main__':
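
For context, a minimal sketch of how the restored test-side apply_rope shim is used: it rebuilds the old (x, start_pos) interface on top of the new (x, freqs_cis) one from tinygrad.apps.llm. The tensor shape and start_pos values below are illustrative, not taken from the commit.

from tinygrad import Tensor
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis

def apply_rope(x:Tensor, start_pos:int):
  # compute the rotation table for positions [start_pos, start_pos+T)
  # and hand it to the new-style apply_rope
  B, H, T, Hd = x.shape
  precompute_freqs_cis.cache_clear()  # drop any cached table so a different Hd/length takes effect
  freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
  return apply_rope_new(x, freqs_cis)

# illustrative usage: a (batch, heads, seq_len, head_dim) tensor decoded from position 4
q = Tensor.randn(1, 8, 16, 64)
out = apply_rope(q, start_pos=4)
print(out.shape)  # RoPE is shape-preserving: (1, 8, 16, 64)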