diff --git a/test/test_schedule.py b/test/test_schedule.py index 722222694a..608aef72de 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast -from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize +from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule from tinygrad.engine.lazy import LazyBuffer from extra.models.llama import precompute_freqs_cis @@ -1938,7 +1938,7 @@ class TestView(unittest.TestCase): np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:]) @track_rewrites(named=True) -def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, realizes) +def big_graph_rewrite(big_graph:UOp, realizes={}) -> UOp: return graph_rewrite(big_graph, do_realize, ScheduleContext(realizes=realizes)) class TestBigGraph(unittest.TestCase): def test_sink_childless_const(self): x = UOp.const(dtypes.int, 0) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 8c896ce144..ebfa46f7a1 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -358,23 +358,23 @@ ops_folding = PatternMatcher([ # ** this decides which ops get realized -def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, **kwargs) -> None: - if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)]) +def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: + if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.realizes.update([(b, to_store)]) -def realize_view(ctx:Dict[UOp, UOp], view:UOp, src:UOp, b:UOp, **kwargs) -> None: +def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None: if src.st is None: return None st = unwrap(view.st) # fold simple pads if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])): - return None if can_pad(src, ctx, set()) else realize(ctx, b, src) + return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src) # early realize before expand if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src) # otherwise safety check pads - return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx, set())) else realize(ctx, b, src) + return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src) -def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]: - if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None - del ctx[b] +def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]: + if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(to_cast).op in GroupOp.Meta: return None + del ctx.realizes[b] return to_cast.view(unwrap(view.st)) def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]: @@ -436,7 +436,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] cache: Dict[LazyBuffer, UOp] = {} buffers: Dict[UOp, Buffer] = {} for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u - big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx.realizes) + big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx) # create the scheduler context graph_rewrite(big_graph, create_ctx, ctx) # group realizes into kernels