clean up toposort sched_sink [pr] (#14439)

This commit is contained in:
chenyu
2026-01-30 10:18:28 -05:00
committed by GitHub
parent 838cd078bc
commit 9eb449f882

View File

@@ -9,38 +9,39 @@ from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
def _unwrap_src(s: UOp) -> UOp:
  # ops at which the walk stops; any other op is stepped through via its first source
  terminal_ops = {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}
  node = s
  while True:
    # a node with no sources cannot be unwrapped any further, so it is returned as-is
    if len(node.src) == 0 or node.op in terminal_ops:
      return node
    node = node.src[0]
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("toposort sched_sink")):
# construct the KERNEL children graph based on assigns
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
if u.op is Ops.RANGE:
in_degree.setdefault(u, 0)
in_degree[u] = 0
continue
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
k = u.src[1]
in_degree.setdefault(k, 0)
in_degree[k] = 0
for s in k.src[0].src if k.op is Ops.END else k.src:
s = _unwrap_src(s)
if s.op is Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children.setdefault(ss.src[1], []).append(k)
in_degree[k] += 1
elif s.op in {Ops.BUFFER, Ops.BIND}:
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
else:
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
in_degree[k] += 1
case Ops.MSELECT | Ops.MSTACK:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children.setdefault(ss.src[1], []).append(k)
in_degree[k] += 1
case Ops.BUFFER | Ops.BIND:
pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below)
case _:
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}")
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque()