mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
make method_cache account for compiler (#12156)
* make method_cache account for compiler * sorry
This commit is contained in:
@@ -5,9 +5,9 @@ from tinygrad.nn.state import get_state_dict
|
||||
|
||||
class TestMethodCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.backup_compiler = Device[Device.DEFAULT].compiler
|
||||
self.backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
|
||||
def tearDown(self):
|
||||
Device[Device.DEFAULT].compiler = self.backup_compiler
|
||||
Device[Device.DEFAULT].compiler.compile_cached = self.backup_compiler
|
||||
|
||||
def test_simple_methodcache(self):
|
||||
a = Tensor([1])
|
||||
@@ -15,19 +15,19 @@ class TestMethodCache(unittest.TestCase):
|
||||
c = Tensor([3])
|
||||
d = Tensor([4])
|
||||
(a+b).realize()
|
||||
Device[Device.DEFAULT].compiler = None
|
||||
Device[Device.DEFAULT].compiler.compile_cached = None
|
||||
(c+d).realize()
|
||||
|
||||
def test_nested_methodcache(self):
|
||||
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
|
||||
((a+b)+(a+b)).realize()
|
||||
Device[Device.DEFAULT].compiler = None
|
||||
Device[Device.DEFAULT].compiler.compile_cached = None
|
||||
((c+d)+(c+d)).realize()
|
||||
|
||||
def test_nested_methodcache_swap(self):
|
||||
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
|
||||
((a+b)+(c+d)).realize()
|
||||
Device[Device.DEFAULT].compiler = None
|
||||
Device[Device.DEFAULT].compiler.compile_cached = None
|
||||
((c+d)+(a+b)).realize()
|
||||
|
||||
@unittest.skip("incorrect use of transformer")
|
||||
@@ -38,7 +38,7 @@ class TestMethodCache(unittest.TestCase):
|
||||
# NOTE: you have to do this twice due to the k-v cache
|
||||
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
|
||||
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
|
||||
Device[Device.DEFAULT].compiler = None
|
||||
Device[Device.DEFAULT].compiler.compile_cached = None
|
||||
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user