mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rope test
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest, base64, functools, sys
|
||||
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re
|
||||
from tinygrad.apps.llm import SimpleTokenizer, get_llama_re, apply_rope
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad import Tensor, dtypes
|
||||
|
||||
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
|
||||
class TestLLMTokenizer(unittest.TestCase):
|
||||
@@ -54,5 +55,13 @@ class TestLLMTokenizer(unittest.TestCase):
|
||||
def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
|
||||
def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
|
||||
|
||||
def test_apply_rope(self):
|
||||
x = Tensor.randn(2, 4, 8, 32, dtype=dtypes.float32)
|
||||
result = apply_rope(x, 0)
|
||||
self.assertEqual(result.shape, x.shape)
|
||||
self.assertEqual(result.dtype, x.dtype)
|
||||
self.assertGreater((result - apply_rope(x, 5)).abs().max().item(), 1e-6)
|
||||
with self.assertRaises(AssertionError): apply_rope(Tensor.randn(1, 1, 4, 7, dtype=dtypes.float32), 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user