From 168c16646aac9bf6858b03da55d96e3968f94fe4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 19 Jan 2025 13:30:26 -0800 Subject: [PATCH] change create_schedule_with_vars api to big_sink [pr] (#8677) --- docs/abstractions2.py | 2 +- test/test_schedule.py | 6 +++--- tinygrad/engine/schedule.py | 3 +-- tinygrad/tensor.py | 3 ++- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/abstractions2.py b/docs/abstractions2.py index e4f76387b0..f09852c6eb 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -91,7 +91,7 @@ b = b.buf_uop_view() out = a.alu(Ops.ADD, b) # schedule the computation as a list of kernels -sched, _, becomes_map = create_schedule_with_vars([out]) +sched, _, becomes_map = create_schedule_with_vars(out.sink()) for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG # NOTE: UOps are no longer mutable, the scheduler gives you a map to lookup which BUFFER the result was written to out = becomes_map[out] diff --git a/test/test_schedule.py b/test/test_schedule.py index 357741a1fe..cf9fd16eb2 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t) else: assert isinstance(t, UOp), f"can't schedule {t}" - sched, _, __ = create_schedule_with_vars([t]) + sched, _, __ = create_schedule_with_vars(t.sink()) # test lowering all the ScheduleItems to ExecItems lowered = list(lower_schedule(sched.copy())) if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)] @@ -2246,12 +2246,12 @@ class TestTensorUOpSpec(unittest.TestCase): ]) t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views+unsafe_push_views) with self.assertRaisesRegex(RuntimeError, "UOp verification failed"): - create_schedule_with_vars(list(t.src)) + create_schedule_with_vars(t) def test_expanded_const_ok(self): a = Tensor.ones((4, 4)) t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views) - create_schedule_with_vars(list(t.src)) + create_schedule_with_vars(t) class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index db0ecccfe8..6ddb7084c0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -519,8 +519,7 @@ remove_movement_ops = PatternMatcher([ ]) @track_rewrites(named=True) -def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: - big_sink = UOp.sink(*outs) +def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: # if using VIZ, do a graph rewrite to vizualize the Tensor graph if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ab89775cf1..a9fb7eb426 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -229,7 +229,8 @@ class Tensor(SimpleMathTrait): NOTE: A Tensor can only be scheduled once. """ - schedule, var_vals, becomes_map = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst])) + big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst])) + schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) # get all children of keys in becomes_map all_uops: set[UOp] = set()