import unittest
from tinygrad import Tensor, dtypes
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis


def apply_rope(x:Tensor, start_pos:int):
  """Apply tinygrad's rotary embedding to ``x`` for positions ``start_pos .. start_pos+T-1``.

  Test-local adapter around ``tinygrad.apps.llm.apply_rope``. It builds a frequency table
  of ``start_pos + T`` rows with ``precompute_freqs_cis`` and passes only the rows for the
  requested window to the library function.

  Args:
    x: a 4-D tensor. The last two dims are used as the sequence length and the head dim.
       (The original code named the dims B, H, T, Hd, presumably batch, heads, seq,
       head_dim. TODO confirm.)
    start_pos: index of the first position to encode.

  Returns:
    The result of ``apply_rope_new(x, freqs_cis)``.
  """
  # Unpacking into exactly four names makes a non-4-D input fail here.
  _, _, seq_len, head_dim = x.shape
  stop = start_pos + seq_len
  # NOTE(review): clears precompute_freqs_cis's cache before every call. Presumably this
  # forces the table to be rebuilt instead of reusing a cached tensor; confirm the motivation.
  precompute_freqs_cis.cache_clear()
  table = precompute_freqs_cis(head_dim, stop)
  return apply_rope_new(x, table[start_pos:stop])


class TestAttention(unittest.TestCase):
  def test_apply_rope(self):
    """RoPE preserves shape/dtype, depends on start_pos, and rejects an odd head dim."""
    x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
    at_zero = apply_rope(x, 0)

    self.assertEqual(at_zero.shape, x.shape)
    self.assertEqual(at_zero.dtype, x.dtype)

    # Encoding the same input at a different start position must change the output.
    max_diff = (at_zero - apply_rope(x, 5)).abs().max().item()
    self.assertGreater(max_diff, 1e-6)

    # A head dim of 7 (odd) is expected to be rejected with an AssertionError.
    with self.assertRaises(AssertionError):
      apply_rope(Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32), 0)


if __name__ == '__main__':
  unittest.main()