make method_cache account for compiler (#12156)

* make method_cache account for compiler

* sorry
This commit is contained in:
nimlgen
2025-09-13 17:00:11 +03:00
committed by GitHub
parent 0c392089d9
commit 92df52d79a
4 changed files with 15 additions and 15 deletions

View File

@@ -5,9 +5,9 @@ from tinygrad.nn.state import get_state_dict
class TestMethodCache(unittest.TestCase):
def setUp(self):
self.backup_compiler = Device[Device.DEFAULT].compiler
self.backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
def tearDown(self):
Device[Device.DEFAULT].compiler = self.backup_compiler
Device[Device.DEFAULT].compiler.compile_cached = self.backup_compiler
def test_simple_methodcache(self):
a = Tensor([1])
@@ -15,19 +15,19 @@ class TestMethodCache(unittest.TestCase):
c = Tensor([3])
d = Tensor([4])
(a+b).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
(c+d).realize()
def test_nested_methodcache(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(a+b)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
((c+d)+(c+d)).realize()
def test_nested_methodcache_swap(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(c+d)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
((c+d)+(a+b)).realize()
@unittest.skip("incorrect use of transformer")
@@ -38,7 +38,7 @@ class TestMethodCache(unittest.TestCase):
# NOTE: you have to do this twice due to the k-v cache
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
if __name__ == '__main__':