remove the map from create_schedule_with_vars [pr] (#10472)

Author: George Hotz
Date: 2025-05-22 15:58:25 -07:00
Committed by: GitHub
Parent: 6d5f87a18a
Commit: 147f7747f2
4 changed files with 14 additions and 12 deletions

View File

@@ -103,7 +103,7 @@ assert assign.src[1].op is Ops.KERNEL
 # schedule the kernel graph in a linear list
 s = UOp(Ops.SINK, dtypes.void, (assign,))
-sched, _, becomes_map = create_schedule_with_vars(s)
+sched, _ = create_schedule_with_vars(s)
 assert len(sched) == 1
 # DEBUGGING: print the compute ast
@@ -111,7 +111,7 @@ print(sched[-1].ast)
 # NOTE: sched[-1].ast is the same as st_0 above
 # the output will be stored in a new buffer
-out = becomes_map[assign]
+out = assign.buf_uop
 assert out.op is Ops.BUFFER and not out.buffer.is_allocated()
 print(out)
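With the map gone, the doc snippet reads the output buffer straight off the ASSIGN. A condensed sketch of the updated flow, reusing the `assign` UOp built earlier in this file (setup omitted):

sink = UOp(Ops.SINK, dtypes.void, (assign,))
sched, var_vals = create_schedule_with_vars(sink)
out = assign.buf_uop  # buf_uop resolves the ASSIGN to the BUFFER it writes (its src[0])
assert out.op is Ops.BUFFER and not out.buffer.is_allocated()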

View File

@@ -30,7 +30,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
   assert isinstance(t, UOp), f"can't schedule {t}"
   sink = UOp.sink(t) if t.op is not Ops.SINK else t
   becomes_map = get_kernelize_map(sink)
-  sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
+  sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
   if kernel_cnt != allowed:
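The helper's contract is unchanged; only the unpacking of create_schedule_with_vars shrinks by one element. A typical call in the style the schedule tests use (shapes illustrative):

a, b = Tensor.empty(16), Tensor.empty(16)
check_schedule(a+b, 1)  # expect exactly one elementwise kernel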

View File

@@ -35,11 +35,11 @@ pm_unbind = PatternMatcher([
 # **** schedule linearizer
-def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
   # construct the KERNEL children graph based on assigns
   children: defaultdict[UOp, list[UOp]] = defaultdict(list)
   in_degree: dict[UOp, int] = {}
-  for u in (toposort:=sched_sink.toposort()):
+  for u in sched_sink.toposort():
     if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
     k = u.src[1]
     in_degree.setdefault(k, 0)
@@ -94,8 +94,4 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
-  # map ASSIGN to BUFFER after ScheduleItems are constructed
-  becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN}
-  assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
-  return schedule, var_vals, becomes_map
+  return schedule, var_vals
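A caller that still needs the ASSIGN-to-BUFFER rewrite can rebuild it from the sink in one line, which is exactly what tensor.py now does (see the last file below). A minimal sketch, assuming `sink` is the kernelized SINK UOp:

remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}  # what the third return value used to be
schedule, var_vals = create_schedule_with_vars(sink)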

View File

@@ -247,8 +247,14 @@ class Tensor(MathTrait):
     """
     st = time.perf_counter()
     self.kernelize(*lst)
-    schedule, var_vals, becomes_map = create_schedule_with_vars(UOp.sink(*[x.lazydata for x in (self,)+lst]))
-    _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
+    sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
+    # remove all ASSIGNs, after scheduling, the tensors are just buffers
+    remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
+    _apply_map_to_tensors(remove_assign_map, name="Remove Assigns")
+    # create the schedule
+    schedule, var_vals = create_schedule_with_vars(sink)
     schedule = memory_planner(schedule)
     if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
     return schedule, var_vals
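Externally nothing changes: Tensor.schedule_with_vars still returns the ScheduleItem list plus the bound Variable values. A quick sanity sketch (tensors illustrative):

from tinygrad import Tensor
out = Tensor.empty(4) + Tensor.empty(4)
schedule, var_vals = out.schedule_with_vars()
print(len(schedule), var_vals)  # e.g. one kernel, no Variables bound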