Add ContextVar to disable the scheduler cache (#14257)

* add scheduler cache ContextVar

* test scheduler cache context var

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
imaolo
2026-01-27 06:55:29 -05:00
committed by GitHub
parent bfc88bcfb8
commit 14574c68fa
3 changed files with 31 additions and 3 deletions

View File

@@ -1,12 +1,16 @@
import unittest
import functools
from tinygrad import Tensor, Variable, UOp, Context
from tinygrad.helpers import cpu_events
from tinygrad.uop.ops import KernelInfo
from tinygrad.engine.schedule import schedule_cache
def custom_set0_kernel(A:UOp, num:int) -> UOp:
  """Build a custom kernel UOp that stores `num` into element 0 of `A`, named per `num`."""
  store = A[0].set(num)
  return store.sink(arg=KernelInfo(f"custom_set0_{num}"))
def schedule_one():
  """Run the scheduler on a trivial one-element tensor (result is discarded)."""
  t = Tensor([1])
  t.schedule()
class TestScheduleCache(unittest.TestCase):
def test_bound_variable_reuses_cache(self):
schedule_cache.clear()
@@ -75,5 +79,27 @@ class TestScheduleCache(unittest.TestCase):
print(num)
self.assertEqual(len(schedule_cache), start_len_schedule_cache)
def test_disable_schedule_cache(self):
  """SCACHE=0 must bypass the scheduler cache on both the write and the read path."""
  schedule_cache.clear()
  # write path: with the cache disabled nothing is stored; with it enabled,
  # repeated identical schedules collapse to a single entry
  with Context(SCACHE=0):
    schedule_one()
  self.assertEqual(len(schedule_cache), 0)
  with Context(SCACHE=1):
    for _ in range(2): schedule_one()
  self.assertEqual(len(schedule_cache), 1)
  # read path: a cache hit should do less scheduling work, which shows up as
  # fewer profiling events than an uncached run
  with Context(PROFILE=1):
    cpu_events.clear()
    with Context(SCACHE=0):
      schedule_one()
    events_without_cache = len(cpu_events)
    cpu_events.clear()
    with Context(SCACHE=1):
      schedule_one()
    events_with_cache = len(cpu_events)
  self.assertLess(events_with_cache, events_without_cache)
# Allow running this test file directly (e.g. `python test_schedule_cache.py`).
if __name__ == "__main__":
  unittest.main()

View File

@@ -4,7 +4,7 @@ from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
@@ -144,7 +144,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals), name="rewrite for sched cache")
sched_cache_key = big_sink_cache.key
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
@@ -168,7 +168,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
else:
# schedule cache hit
del big_sink_cache

View File

@@ -199,6 +199,8 @@ TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
CCACHE = ContextVar("CCACHE", 1)
# allow tf32 to be used on NVIDIA GPUs
ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
# set to 0 to disable the scheduler cache
SCACHE = ContextVar("SCACHE", 1)
@dataclass(frozen=True)
class Metadata: