start work on schedule cache (#13529)

* start work on schedule cache

* local unique

* schedule cache works

* schedule cache cleanup

* fix tests

* preserve metadata

* oops, fix cache

* put that there

* fix spec

* always miss

* why is that broken?

* src[0].op

* fix process replay

* delete abstractions2

* reenable the actual schedule cache

* metadata is best effort

* fix JIT in examples/gradaccum_mnist.py

* full jit

* fixed and test is real
This commit is contained in:
George Hotz
2025-12-04 17:24:49 -08:00
committed by GitHub
parent 62e2fc5108
commit c5bd28e21d
5 changed files with 104 additions and 20 deletions

View File

@@ -1,5 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.schedule import schedule_cache
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
@@ -68,6 +69,7 @@ if __name__ == "__main__":
t()
# these caches will keep uops alive
schedule_cache.clear()
method_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()

View File

@@ -829,6 +829,7 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()

View File

@@ -0,0 +1,24 @@
import unittest
from tinygrad import Tensor
from tinygrad.engine.schedule import schedule_cache
class TestScheduleCache(unittest.TestCase):
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()
Tensor.realize(a, b)
# warm up
for _ in range(2):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
# confirm schedule cache doesn't grow
start_len_schedule_cache = len(schedule_cache)
for _ in range(3):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
self.assertEqual(len(schedule_cache), start_len_schedule_cache)
if __name__ == "__main__":
unittest.main()