Mirror of https://github.com/tinygrad/tinygrad.git
Add ContextVar to disable the scheduler cache (#14257)
* add scheduler cache ContextVar
* test scheduler cache context var

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
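As a usage note (not part of the diff): SCACHE is an ordinary tinygrad ContextVar, so it can be toggled per-block with Context or, like tinygrad's other ContextVars, seeded from the environment (e.g. SCACHE=0 python script.py). A minimal sketch, assuming only the public tinygrad API plus the SCACHE variable this change adds; the tensor work is purely illustrative:

from tinygrad import Tensor, Context

with Context(SCACHE=0):
  (Tensor([1, 2, 3]) + 1).schedule()  # scheduler cache is neither read nor written

with Context(SCACHE=1):  # the default: identical graphs reuse cached schedules
  (Tensor([1, 2, 3]) + 1).schedule()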
@@ -1,12 +1,16 @@
 import unittest
 import functools
 from tinygrad import Tensor, Variable, UOp, Context
+from tinygrad.helpers import cpu_events
 from tinygrad.uop.ops import KernelInfo
 from tinygrad.engine.schedule import schedule_cache

 def custom_set0_kernel(A:UOp, num:int) -> UOp:
   return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))

+def schedule_one():
+  Tensor([1]).schedule()
+
 class TestScheduleCache(unittest.TestCase):
   def test_bound_variable_reuses_cache(self):
     schedule_cache.clear()
@@ -75,5 +79,27 @@ class TestScheduleCache(unittest.TestCase):
       print(num)
     self.assertEqual(len(schedule_cache), start_len_schedule_cache)

+  def test_disable_schedule_cache(self):
+    schedule_cache.clear()
+
+    # test write
+    with Context(SCACHE=0): schedule_one()
+    self.assertEqual(len(schedule_cache), 0)
+    with Context(SCACHE=1):
+      schedule_one()
+      schedule_one()
+    self.assertEqual(len(schedule_cache), 1)
+
+    # test read
+    with Context(PROFILE=1):
+      cpu_events.clear()
+      with Context(SCACHE=0): schedule_one()
+      num_events_no_cache = len(cpu_events)
+
+      cpu_events.clear()
+      with Context(SCACHE=1): schedule_one()
+      num_events_cache = len(cpu_events)
+      self.assertLess(num_events_cache, num_events_no_cache)
+
 if __name__ == "__main__":
   unittest.main()
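A note on the read half of test_disable_schedule_cache: PROFILE=1 makes the scheduler's cpu_profile blocks record into cpu_events, so a run that hits the cache skips the expensive passes and logs fewer events. A hedged sketch of that measurement idea, with fake_schedule and profile_events as stand-ins (they are not tinygrad APIs):

# illustrative only: counting profiling events as a proxy for work done
profile_events: list[str] = []

def fake_schedule(cache_hit: bool):
  profile_events.append("cache lookup")
  if not cache_hit:  # the expensive passes only run on a miss
    profile_events.append("graph rewrite")
    profile_events.append("linearize")

fake_schedule(cache_hit=False); n_miss = len(profile_events)
profile_events.clear()
fake_schedule(cache_hit=True); n_hit = len(profile_events)
assert n_hit < n_miss  # mirrors the assertLess in the test above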
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -4,7 +4,7 @@ from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
+from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
 from tinygrad.engine.realize import ExecItem

 # **** schedule linearizer
@@ -144,7 +144,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals), name="rewrite for sched cache")
   sched_cache_key = big_sink_cache.key

-  if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
+  if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
     # verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
     if SPEC: type_verify(big_sink, tensor_spec)

@@ -168,7 +168,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
     after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
     tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
     combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
-    schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
+    if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
   else:
     # schedule cache hit
     del big_sink_cache
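The two hunks above gate both sides of the cache on the same flag: the lookup short-circuits to a miss when SCACHE is off, and the store is skipped, so a disabled cache never grows and never serves entries. A minimal sketch of the pattern, with illustrative names rather than tinygrad's:

cache: dict = {}

def compute_with_cache(key, compute, enabled: bool = True):
  # when disabled, "not enabled" forces the miss path, matching the diff
  if not enabled or (ret := cache.get(key)) is None:
    ret = compute()  # do the real work
    if enabled: cache[key] = ret  # only populate when the cache is on
  return ret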
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -199,6 +199,8 @@ TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
 CCACHE = ContextVar("CCACHE", 1)
 # allow tf32 to be used on NVIDIA GPUs
 ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
+# set to 0 to disable the scheduler cache
+SCACHE = ContextVar("SCACHE", 1)

 @dataclass(frozen=True)
 class Metadata:
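SCACHE plugs into tinygrad's existing ContextVar machinery. Roughly, and simplified (this is a sketch of the behavior, not tinygrad.helpers' implementation): a ContextVar takes its default from an environment variable of the same name, and Context temporarily overrides it for a block.

import os
from contextlib import contextmanager

_registry: dict[str, int] = {}

def context_var(key: str, default: int) -> str:
  _registry[key] = int(os.getenv(key, default))  # the env var seeds the default
  return key

@contextmanager
def ctx(**kwargs):
  old = {k: _registry[k] for k in kwargs}
  _registry.update({k: int(v) for k, v in kwargs.items()})
  try: yield
  finally: _registry.update(old)  # values restore on exit

SCACHE = context_var("SCACHE", 1)
with ctx(SCACHE=0):
  assert _registry["SCACHE"] == 0  # disabled inside the block
assert _registry["SCACHE"] == 1  # restored (assuming SCACHE isn't set in the env)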