mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
start work on schedule cache (#13529)
* start work on schedule cache * local unique * schedule cache works * schedule cache cleanup * fix tests * preserve metadata * oops, fix cache * put that there * fix spec * always miss * why is that broken? * src[0].op * fix process replay * delete abstractions2 * reenable the actual schedule cache * metadata is best effort * fix JIT in examples/gradaccum_mnist.py * full jit * fixed and test is real
This commit is contained in:
2
test/external/external_uop_gc.py
vendored
2
test/external/external_uop_gc.py
vendored
@@ -1,5 +1,6 @@
|
||||
import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.engine.schedule import schedule_cache
|
||||
from tinygrad.engine.realize import method_cache, get_program
|
||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
@@ -68,6 +69,7 @@ if __name__ == "__main__":
|
||||
t()
|
||||
|
||||
# these caches will keep uops alive
|
||||
schedule_cache.clear()
|
||||
method_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_reshape.cache_clear()
|
||||
|
||||
@@ -829,6 +829,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
|
||||
def test_complex_backward(self):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
|
||||
24
test/unit/test_schedule_cache.py
Normal file
24
test/unit/test_schedule_cache.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.schedule import schedule_cache
|
||||
|
||||
class TestScheduleCache(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
a = Tensor.ones(10).contiguous()
|
||||
b = Tensor.ones(10).contiguous()
|
||||
Tensor.realize(a, b)
|
||||
|
||||
# warm up
|
||||
for _ in range(2):
|
||||
num = (a.sum().contiguous()+b.sum().contiguous()).item()
|
||||
print(num)
|
||||
|
||||
# confirm schedule cache doesn't grow
|
||||
start_len_schedule_cache = len(schedule_cache)
|
||||
for _ in range(3):
|
||||
num = (a.sum().contiguous()+b.sum().contiguous()).item()
|
||||
print(num)
|
||||
self.assertEqual(len(schedule_cache), start_len_schedule_cache)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user