import unittest

from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis


def apply_rope(x:Tensor, start_pos:int):
  # legacy-style wrapper: build freqs_cis for positions [start_pos, start_pos+T) and delegate
  B, H, T, Hd = x.shape
  precompute_freqs_cis.cache_clear()
  freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
  return apply_rope_new(x, freqs_cis)
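
# Usage sketch (shapes illustrative, matching the jit test below; not asserted anywhere):
#   x = Tensor.randn(1, 2, 4, 8)        # (B, H, T, Hd)
#   out = apply_rope(x, start_pos=0)    # same as apply_rope_new(x, precompute_freqs_cis(8, 4))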


class TestAttention(unittest.TestCase):
  def test_half_qkv_buffers(self):
    BS, seqlen, dim = 10, 4, 100
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention currently lowers to 4 kernels
    self.assertEqual(len(sched), 4)
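    # (debugging aid, not part of the test: each ScheduleItem is assumed to expose its AST,
    # so the four kernels can be inspected with something like)
    #   for si in sched: print(si.ast)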

  def test_apply_rope_jit_prune(self):
    def rope_fn(x_in, pos): return apply_rope(x_in, pos)
    # capture the same function twice: once as-is, once with pruning enabled
    rope_noprune = TinyJit(rope_fn)
    rope_prune = TinyJit(rope_fn, prune=True)

    v_pos = UOp.variable("start_pos", 0, 100)
    # first call runs eagerly, second captures, third replays the captured cache
    for _ in range(3):
      rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
      rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
    noprune_size = len(rope_noprune.captured.jit_cache)
    prune_size = len(rope_prune.captured.jit_cache)

    # pruning should strictly shrink the jit cache: >=2 kernels without it, exactly 1 with it
    self.assertGreater(noprune_size, prune_size)
    self.assertGreaterEqual(noprune_size, 2)
    self.assertEqual(prune_size, 1)
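
# Our reading of why prune=True reaches a single kernel (an interpretation, not documented
# API behavior): the freqs_cis precompute reads no input buffers, only the bound start_pos
# variable, so pruning can drop it from the jit_cache.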

if __name__ == '__main__':
  unittest.main()