From 14574c68fa1230cb27875fbe72d9f31cff43fd0b Mon Sep 17 00:00:00 2001 From: imaolo <56898718+imaolo@users.noreply.github.com> Date: Tue, 27 Jan 2026 06:55:29 -0500 Subject: [PATCH] Add ContextVar to disable the scheduler cache (#14257) * add scheduler cache ContextVar * test scheduler cache context var --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- test/unit/test_schedule_cache.py | 26 ++++++++++++++++++++++++++ tinygrad/engine/schedule.py | 6 +++--- tinygrad/helpers.py | 2 ++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/test/unit/test_schedule_cache.py b/test/unit/test_schedule_cache.py index 0d56dafa86..71fb85b831 100644 --- a/test/unit/test_schedule_cache.py +++ b/test/unit/test_schedule_cache.py @@ -1,12 +1,16 @@ import unittest import functools from tinygrad import Tensor, Variable, UOp, Context +from tinygrad.helpers import cpu_events from tinygrad.uop.ops import KernelInfo from tinygrad.engine.schedule import schedule_cache def custom_set0_kernel(A:UOp, num:int) -> UOp: return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}")) +def schedule_one(): + Tensor([1]).schedule() + class TestScheduleCache(unittest.TestCase): def test_bound_variable_reuses_cache(self): schedule_cache.clear() @@ -75,5 +79,27 @@ class TestScheduleCache(unittest.TestCase): print(num) self.assertEqual(len(schedule_cache), start_len_schedule_cache) + def test_disable_schedule_cache(self): + schedule_cache.clear() + + # test write + with Context(SCACHE=0): schedule_one() + self.assertEqual(len(schedule_cache), 0) + with Context(SCACHE=1): + schedule_one() + schedule_one() + self.assertEqual(len(schedule_cache), 1) + + # test read + with Context(PROFILE=1): + cpu_events.clear() + with Context(SCACHE=0): schedule_one() + num_events_no_cache = len(cpu_events) + + cpu_events.clear() + with Context(SCACHE=1): schedule_one() + num_events_cache = len(cpu_events) + self.assertLess(num_events_cache, num_events_no_cache) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e0c4b534ab..ab622cbc18 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -4,7 +4,7 @@ from collections import deque from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize +from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE from tinygrad.engine.realize import ExecItem # **** schedule linearizer @@ -144,7 +144,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals), name="rewrite for sched cache") sched_cache_key = big_sink_cache.key - if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: + if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: # verify Tensors match the spec (on big_sink, we only need to do this if cache misses) if SPEC: type_verify(big_sink, tensor_spec) @@ -168,7 +168,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER] tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map)) combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink) - schedule_cache[sched_cache_key] = (pre_schedule, combined_sink) + if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink) else: # schedule cache hit del big_sink_cache diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index d4b0e7d053..1bc0578617 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -199,6 +199,8 @@ TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1) CCACHE = ContextVar("CCACHE", 1) # allow tf32 to be used on NVIDIA GPUs ALLOW_TF32 = ContextVar("ALLOW_TF32", 0) +# set to 0 to disable the scheduler cache +SCACHE = ContextVar("SCACHE", 1) @dataclass(frozen=True) class Metadata: