diff --git a/docs/abstractions2.py b/docs/abstractions2.py index e4f76387b0..f09852c6eb 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -91,7 +91,7 @@ b = b.buf_uop_view() out = a.alu(Ops.ADD, b) # schedule the computation as a list of kernels -sched, _, becomes_map = create_schedule_with_vars([out]) +sched, _, becomes_map = create_schedule_with_vars(out.sink()) for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG # NOTE: UOps are no longer mutable, the scheduler gives you a map to lookup which BUFFER the result was written to out = becomes_map[out] diff --git a/test/test_schedule.py b/test/test_schedule.py index 357741a1fe..cf9fd16eb2 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t) else: assert isinstance(t, UOp), f"can't schedule {t}" - sched, _, __ = create_schedule_with_vars([t]) + sched, _, __ = create_schedule_with_vars(t.sink()) # test lowering all the ScheduleItems to ExecItems lowered = list(lower_schedule(sched.copy())) if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)] @@ -2246,12 +2246,12 @@ class TestTensorUOpSpec(unittest.TestCase): ]) t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views+unsafe_push_views) with self.assertRaisesRegex(RuntimeError, "UOp verification failed"): - create_schedule_with_vars(list(t.src)) + create_schedule_with_vars(t) def test_expanded_const_ok(self): a = Tensor.ones((4, 4)) t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views) - create_schedule_with_vars(list(t.src)) + create_schedule_with_vars(t) class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index db0ecccfe8..6ddb7084c0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -519,8 +519,7 @@ remove_movement_ops = PatternMatcher([ ]) @track_rewrites(named=True) -def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: - big_sink = UOp.sink(*outs) +def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: # if using VIZ, do a graph rewrite to vizualize the Tensor graph if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ab89775cf1..a9fb7eb426 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -229,7 +229,8 @@ class Tensor(SimpleMathTrait): NOTE: A Tensor can only be scheduled once. """ - schedule, var_vals, becomes_map = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst])) + big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst])) + schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) # get all children of keys in becomes_map all_uops: set[UOp] = set()