mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fast tinygrad.apps.llm (#13685)
* llm: add --benchmark support * fix speed * debug logging * fix test attention
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.apps.llm import apply_rope
|
||||
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||
#from tinygrad.engine.realize import run_schedule
|
||||
|
||||
def apply_rope(x:Tensor, start_pos:int):
|
||||
B, H, T, Hd = x.shape
|
||||
precompute_freqs_cis.cache_clear()
|
||||
freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
|
||||
return apply_rope_new(x, freqs_cis)
|
||||
|
||||
# TODO: test_scheduler, but just in uint
|
||||
class TestAttention(unittest.TestCase):
|
||||
def test_half_qkv_buffers(self):
|
||||
@@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
|
||||
prune_size = len(rope_prune.captured.jit_cache)
|
||||
|
||||
self.assertGreater(noprune_size, prune_size)
|
||||
self.assertGreaterEqual(noprune_size, 3)
|
||||
self.assertGreaterEqual(noprune_size, 2)
|
||||
self.assertEqual(prune_size, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user