start work on schedule cache (#13529)

* start work on schedule cache

* local unique

* schedule cache works

* schedule cache cleanup

* fix tests

* preserve metadata

* oops, fix cache

* put that there

* fix spec

* always miss

* why is that broken?

* src[0].op

* fix process replay

* delete abstractions2

* reenable the actual schedule cache

* metadata is best effort

* fix JIT in examples/gradaccum_mnist.py

* full jit

* fixed and test is real
George Hotz
2025-12-04 17:24:49 -08:00
committed by GitHub
parent 62e2fc5108
commit c5bd28e21d
5 changed files with 104 additions and 20 deletions

View File

@@ -1,5 +1,6 @@
 import gc
 from tinygrad import Tensor, UOp, Device, nn
+from tinygrad.engine.schedule import schedule_cache
 from tinygrad.engine.realize import method_cache, get_program
 from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
 from tinygrad.uop.divandmod import fold_divmod_general
@@ -68,6 +69,7 @@ if __name__ == "__main__":
   t()
   # these caches will keep uops alive
+  schedule_cache.clear()
   method_cache.clear()
   apply_movement_op.cache_clear()
   _apply_reshape.cache_clear()
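The hunk above adds schedule_cache.clear() to this teardown because any global cache that holds UOps keeps them reachable, so a test that counts surviving UOps after garbage collection would otherwise report false leaks. A minimal sketch of that mechanism in plain Python, with a stand-in Node class and cache dict rather than tinygrad's UOp and schedule_cache:

import gc

class Node: pass                       # stand-in for UOp

cache: dict[int, Node] = {}            # stand-in for schedule_cache / method_cache

def count_live() -> int:
  # count Node objects still reachable after a full collection
  gc.collect()
  return sum(1 for obj in gc.get_objects() if isinstance(obj, Node))

def work():
  n = Node()
  cache[id(n)] = n                     # the cache keeps n alive after work() returns

work()
assert count_live() == 1               # the cache pins the object
cache.clear()                          # so a leak check must clear every such cache first
assert count_live() == 0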

View File

@@ -829,6 +829,7 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(len(si.metadata), 3)
     self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
 
+  @unittest.skip("metadata is no longer promised to be exact with schedulecache")
   def test_complex_backward(self):
     x = Tensor.rand(3, requires_grad=True).realize()
     y = Tensor.rand(3, requires_grad=True).realize()
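The new skip lines up with the "metadata is best effort" note in the commit message: once schedules are memoized, a cache hit hands back kernels carrying whatever metadata was recorded when the entry was first built, so a later call with the same graph shape can no longer be promised exact metadata. A toy illustration of why memoization behaves this way (lower, expr, and the metadata strings are illustrative, not tinygrad's API):

cache: dict[str, tuple[str, str]] = {}

def lower(expr: str, metadata: str) -> tuple[str, str]:
  # metadata is only recorded on a miss; hits return whatever was stored then
  if expr not in cache:
    cache[expr] = (f"kernel for {expr}", metadata)
  return cache[expr]

first = lower("relu(x)", metadata="relu from forward")
second = lower("relu(x)", metadata="relu from backward")
assert second[1] == "relu from forward"    # best effort: not this call's metadata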

View File

@@ -0,0 +1,24 @@
+import unittest
+from tinygrad import Tensor
+from tinygrad.engine.schedule import schedule_cache
+
+class TestScheduleCache(unittest.TestCase):
+  def test_simple(self):
+    a = Tensor.ones(10).contiguous()
+    b = Tensor.ones(10).contiguous()
+    Tensor.realize(a, b)
+
+    # warm up
+    for _ in range(2):
+      num = (a.sum().contiguous()+b.sum().contiguous()).item()
+      print(num)
+
+    # confirm schedule cache doesn't grow
+    start_len_schedule_cache = len(schedule_cache)
+    for _ in range(3):
+      num = (a.sum().contiguous()+b.sum().contiguous()).item()
+      print(num)
+    self.assertEqual(len(schedule_cache), start_len_schedule_cache)
+
+if __name__ == "__main__":
+  unittest.main()

View File

@@ -3,6 +3,7 @@ from typing import cast
 from dataclasses import dataclass, field, replace
 from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
+from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
@@ -113,24 +114,77 @@ from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
+
+def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
+  if (ret:=ctx.get(b, None)) is None:
+    if b.op is Ops.BUFFER:
+      ctx[b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx)), b.src[1]))
+    else:
+      # TODO: flip args in CONST
+      assert b.op is Ops.CONST
+      ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
+  return ret
+
+pm_pre_sched_cache = PatternMatcher([
+  # replace input buffers
+  (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
+  # remove unique consts
+  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
+])
+
+def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
+  if (ret:=ctx.get(b, None)) is None:
+    assert b.op is Ops.BUFFER
+    # if it's not in the cache, create a new buffer
+    ctx[b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
+  return ret
+
+pm_post_sched_cache = PatternMatcher([
+  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
+  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
+])
+
+schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
 def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
 
-  # verify Tensors match the spec
-  if SPEC: type_verify(big_sink, tensor_spec)
+  # replace all UNIQUE buffers with LUNIQUE
+  input_buffers: dict[UOp, UOp] = {}
+  big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
+  sched_cache_key = big_sink_cache.key
-  # tensor map is what we return
-  tensor_map: dict[UOp, UOp] = {}
+  if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
+    # verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
+    if SPEC: type_verify(big_sink, tensor_spec)
-  if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
-    tensor_map |= get_multi_map(big_sink)
-    big_sink = big_sink.substitute(tensor_map, name="Apply Multi Map")
-    big_sink = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink.src]))
+    # hack to preserve metadata
+    graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx={}, name="preserve metadata")
-  tensor_map |= get_rangeify_map(big_sink)
-  big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
+    # tensor map is what we return
+    tensor_map: dict[UOp, UOp] = {}
+    if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
+      tensor_map |= get_multi_map(big_sink_cache)
+      big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
+      big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
+    tensor_map |= get_rangeify_map(big_sink_cache)
+    big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
+    # save in schedule cache
+    tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
+    schedule_cache[sched_cache_key] = (big_sink, tensor_map_sink)
+  else:
+    # schedule cache hit
+    del big_sink_cache
+    big_sink, tensor_map_sink = sc_ret
+  # replace all the LUNIQUEs with UNIQUEs
+  input_buffers_reverse = {v:k for k,v in input_buffers.items()}
+  big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
+  tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
+  tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
   # create the schedule
   schedule, var_vals = create_schedule_with_vars(big_sink)
@@ -140,5 +194,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
   if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
-    print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
+    print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
+          f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
+          f" | {len(UOpMetaClass.ucache)} uops in cache")
   return tensor_map, schedule, var_vals
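Read together, the pre/post pattern matchers above implement a round trip: before hashing, each concrete input BUFFER (normally identified by a globally UNIQUE id) is renumbered with a LUNIQUE in first-seen order, so two launches that differ only in which buffers feed the graph produce the same cache key; after lookup, the stored graph's LUNIQUEs are mapped back to this call's real buffers, and any LUNIQUE without a counterpart (an intermediate created during lowering) becomes a fresh buffer. A rough standalone sketch of that round trip using tuples instead of UOps (every name below is a stand-in, not tinygrad's data model):

def to_local(graph: list[tuple], mapping: dict) -> list[tuple]:
  # replace each concrete input buffer id with a local id in first-seen order
  out = []
  for op, arg in graph:
    if op == "BUFFER":
      if arg not in mapping: mapping[arg] = len(mapping)   # UNIQUE -> LUNIQUE
      out.append((op, mapping[arg]))
    else: out.append((op, arg))
  return out

def to_global(graph: list[tuple], reverse: dict) -> list[tuple]:
  # map local ids back to this call's buffers; unknown locals become fresh buffers
  out = []
  for op, arg in graph:
    if op == "BUFFER":
      if arg not in reverse: reverse[arg] = f"fresh_buffer_{arg}"
      out.append((op, reverse[arg]))
    else: out.append((op, arg))
  return out

cache: dict[tuple, list] = {}

def schedule(graph: list[tuple]) -> list[tuple]:
  mapping: dict = {}
  local = to_local(graph, mapping)
  key = tuple(local)
  if key not in cache:
    # pretend lowering adds one intermediate buffer to the cached graph
    cache[key] = local + [("BUFFER", len(mapping))]
  reverse = {v: k for k, v in mapping.items()}
  return to_global(cache[key], reverse)

# two calls over different concrete inputs share one cache entry
schedule([("BUFFER", "bufA"), ("ADD", None), ("BUFFER", "bufB")])
schedule([("BUFFER", "bufX"), ("ADD", None), ("BUFFER", "bufY")])
assert len(cache) == 1

Keying on the normalized graph rather than the raw one is what lets a loop that rebuilds the same tensor expressions every step reuse a single cache entry instead of rescheduling each time.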

View File

@@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
 from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
-from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
+from tinygrad.helpers import PCONTIG, partition, get_single_element
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
 from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -297,7 +297,7 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
 # BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
 # NOTE: this has been fixed up a bit
-def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=True):
+def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
   #assert isinstance(x.tag, Flat), "bufferize must be flat"
   size = prod(x.shape)
   rngs = sorted(idx.ranges, key=lambda x: x.arg)
@@ -323,7 +323,7 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
   if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
     assert sdtype.addrspace == AddrSpace.GLOBAL
     outer_range = x.src[0].src[1]
-    buf = UOp.new_buffer(x.arg.device, size, x.dtype)
+    buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
     # NOTE: this has the same number as the outer range, we need string ranges!
     zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
     buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
@@ -333,13 +333,13 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
   # NOTE: the DEFINE_LOCAL needs to be disambiguated here
   if sdtype.addrspace == AddrSpace.GLOBAL:
-    buf = UOp.new_buffer(x.arg.device, size, x.dtype)
+    buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
     do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
     return buf.after(do_store)
   if allow_locals:
     # handle locals
-    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(unwrap(ctx)))
+    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(ctx))
     do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
     return buf.after(do_store.barrier())
@@ -356,7 +356,7 @@ def flatten_bufferize(x:UOp):
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
 pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
-  (UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(None, x, idx, allow_locals=False)),
+  (UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
   # move RESHAPEs through MSELECT/MSTACK
   (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
@@ -503,7 +503,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
   kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
   if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
-    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
+    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
   return kernel
split_kernels = PatternMatcher([
@@ -517,7 +517,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([
   # don't tag BUFFERs, they are global
-  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
+  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
         Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
   (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
 ])
@@ -560,7 +560,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
   # bufferize -> store
-  tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, bottom_up=True, name="bufferize to store")
+  lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
+  tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
   tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
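The rangeify changes above make intermediate-buffer creation explicit and local: instead of minting a globally UNIQUE buffer via UOp.new_buffer, bufferize_to_store now draws LUNIQUE ids from an itertools.count threaded through the rewrite context, started just past the largest LUNIQUE already present in the sink. A small sketch of just that numbering rule, with plain integers instead of UOps:

import itertools

def fresh_local_ids(existing: list[int], needed: int) -> list[int]:
  # mirror max([-1]+existing)+1 from the diff: start after the largest existing id
  counter = itertools.count(max([-1] + existing) + 1)
  return [next(counter) for _ in range(needed)]

assert fresh_local_ids([0, 1, 4], 3) == [5, 6, 7]   # never collides with existing LUNIQUEs
assert fresh_local_ids([], 2) == [0, 1]             # empty sink starts at 0

The apparent reason for using LUNIQUE placeholders here rather than real buffers is that the cached graph must not pin concrete allocations: on every call, hit or miss, replace_input_buffer_back in schedule.py turns these placeholders back into freshly created buffers.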