diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index 1971217f0e..ea8db48093 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -103,7 +103,7 @@ assert assign.src[1].op is Ops.KERNEL
 
 # schedule the kernel graph in a linear list
 s = UOp(Ops.SINK, dtypes.void, (assign,))
-sched, _, becomes_map = create_schedule_with_vars(s)
+sched, _ = create_schedule_with_vars(s)
 assert len(sched) == 1
 
 # DEBUGGING: print the compute ast
@@ -111,7 +111,7 @@ print(sched[-1].ast)
 # NOTE: sched[-1].ast is the same as st_0 above
 
 # the output will be stored in a new buffer
-out = becomes_map[assign]
+out = assign.buf_uop
 assert out.op is Ops.BUFFER and not out.buffer.is_allocated()
 print(out)
 
diff --git a/test/test_schedule.py b/test/test_schedule.py
index ea622b6703..c10ca48fa5 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -30,7 +30,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
   assert isinstance(t, UOp), f"can't schedule {t}"
   sink = UOp.sink(t) if t.op is not Ops.SINK else t
   becomes_map = get_kernelize_map(sink)
-  sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
+  sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
   if kernel_cnt != allowed:
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 6b03b33e2e..4ffa2d9870 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -35,11 +35,11 @@ pm_unbind = PatternMatcher([
 
 # **** schedule linearizer
 
-def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
   # construct the KERNEL children graph based on assigns
   children: defaultdict[UOp, list[UOp]] = defaultdict(list)
   in_degree: dict[UOp, int] = {}
-  for u in (toposort:=sched_sink.toposort()):
+  for u in sched_sink.toposort():
     if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
     k = u.src[1]
     in_degree.setdefault(k, 0)
@@ -94,8 +94,4 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
 
-  # map ASSIGN to BUFFER after ScheduleItems are constructed
-  becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN}
-  assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
-
-  return schedule, var_vals, becomes_map
+  return schedule, var_vals
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8f8653413d..5dc666d792 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -247,8 +247,14 @@ class Tensor(MathTrait):
     """
     st = time.perf_counter()
     self.kernelize(*lst)
-    schedule, var_vals, becomes_map = create_schedule_with_vars(UOp.sink(*[x.lazydata for x in (self,)+lst]))
-    _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
+    sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
+
+    # remove all ASSIGNs, after scheduling, the tensors are just buffers
+    remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
+    _apply_map_to_tensors(remove_assign_map, name="Remove Assigns")
+
+    # create the schedule
+    schedule, var_vals = create_schedule_with_vars(sink)
     schedule = memory_planner(schedule)
     if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
     return schedule, var_vals
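
For reference, a minimal sketch of the new calling convention after this change, assuming a standard tinygrad install. The toy tensor and the import paths are illustrative assumptions; the rest mirrors the updated tinygrad/tensor.py flow in this diff, where create_schedule_with_vars returns only (schedule, var_vals) and ASSIGNs are mapped to their buffers on the Tensor side.

# minimal sketch, not part of the patch; import paths are assumptions
# (create_schedule_with_vars lives in tinygrad/engine/schedule.py per this diff)
from tinygrad import Tensor
from tinygrad.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule_with_vars

t = Tensor.arange(4) + 1            # illustrative toy computation
t.kernelize()                       # group the UOp graph into KERNELs
sink = UOp.sink(t.lazydata)

# the scheduler no longer returns a becomes_map; tensor.py now builds the
# ASSIGN -> BUFFER map itself and applies it via _apply_map_to_tensors
remove_assign_map = {u: u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}

# create_schedule_with_vars now returns only (schedule, var_vals)
schedule, var_vals = create_schedule_with_vars(sink)
print(f"{len(schedule)} kernel(s) scheduled")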