rope test

This commit is contained in:
Nino Risteski
2025-08-18 08:50:50 +02:00
parent f3a7e009a2
commit 69ede543d0

View File

@@ -1,6 +1,7 @@
import unittest, base64, functools, sys
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re, apply_rope
from tinygrad.helpers import fetch
from tinygrad import Tensor, dtypes
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
class TestLLMTokenizer(unittest.TestCase):
@@ -54,5 +55,13 @@ class TestLLMTokenizer(unittest.TestCase):
def test_llama_repeat(self):
  """Long runs of a repeated character split into repeated tokens plus a remainder."""
  # 17 zeros -> five "000" chunks (token 931) and a trailing "00" (token 410)
  expected = [931, 931, 931, 931, 931, 410]
  self._test_coding(self.llama_tok, "00000000000000000", expected)
def test_llama_pat(self):
  """Whitespace/newline sequences exercise the llama pre-tokenization regex."""
  expected = [31213, 14211]
  self._test_coding(self.llama_tok, "today\n \n", expected)
def test_apply_rope(self):
  """apply_rope preserves shape/dtype, depends on position, and rejects odd head dims."""
  x = Tensor.randn(2, 4, 8, 32, dtype=dtypes.float32)
  rotated = apply_rope(x, 0)
  # rotation is elementwise over (real, imag) pairs: shape and dtype pass through
  self.assertEqual(rotated.shape, x.shape)
  self.assertEqual(rotated.dtype, x.dtype)
  # a different start position must produce a measurably different result
  position_delta = (rotated - apply_rope(x, 5)).abs().max().item()
  self.assertGreater(position_delta, 1e-6)
  # the last (head) dimension must be even to pair real/imag halves; 7 should assert
  odd_head = Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32)
  with self.assertRaises(AssertionError):
    apply_rope(odd_head, 0)
if __name__ == '__main__':
unittest.main()