Mirror of https://github.com/tinygrad/tinygrad.git — synced 2026-04-29 03:00:14 -04:00.
* move more tests to test/null, split some existing ones
* null work
* null work
* move more
* fixes
* move PIL
* PIL in CLIP
* don't move that
22 lines · 821 B · Python
import unittest
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
|
|
|
|
def apply_rope(x:Tensor, start_pos:int):
  """Reference wrapper around the new rotary-embedding API.

  Recomputes the frequency table from scratch for positions
  [start_pos, start_pos + T) and applies it to *x* with `apply_rope_new`.
  *x* must be rank-4; the last axis is the head dimension.
  """
  _, _, seq_len, head_dim = x.shape
  # Drop any memoized table so the frequencies are rebuilt for this call.
  precompute_freqs_cis.cache_clear()
  table = precompute_freqs_cis(head_dim, start_pos + seq_len)
  return apply_rope_new(x, table[start_pos:start_pos + seq_len])
|
|
|
|
class TestAttention(unittest.TestCase):
  def test_apply_rope(self):
    """apply_rope preserves shape/dtype, varies with position, rejects odd head dims."""
    inp = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
    out = apply_rope(inp, 0)
    # The rotation is shape- and dtype-preserving.
    self.assertEqual(out.shape, inp.shape)
    self.assertEqual(out.dtype, inp.dtype)
    # Different start positions must produce measurably different rotations.
    gap = (out - apply_rope(inp, 5)).abs().max().item()
    self.assertGreater(gap, 1e-6)
    # An odd head dimension cannot be split into rotation pairs.
    bad = Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32)
    with self.assertRaises(AssertionError):
      apply_rope(bad, 0)
|
|
|
|
# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
  unittest.main()
|