mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up toposort sched_sink [pr] (#14439)
This commit is contained in:
@@ -9,38 +9,39 @@ from tinygrad.engine.realize import ExecItem
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||
def _unwrap_src(s: UOp) -> UOp:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
return s
|
||||
|
||||
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||
# construct the KERNEL children graph based on assigns
|
||||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
if u.op is Ops.RANGE:
|
||||
in_degree.setdefault(u, 0)
|
||||
in_degree[u] = 0
|
||||
continue
|
||||
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
|
||||
k = u.src[1]
|
||||
in_degree.setdefault(k, 0)
|
||||
in_degree[k] = 0
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
s = _unwrap_src(s)
|
||||
if s.op is Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
children.setdefault(ss.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op in {Ops.BUFFER, Ops.BIND}:
|
||||
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
|
||||
else:
|
||||
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
case Ops.MSELECT | Ops.MSTACK:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
children.setdefault(ss.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
case Ops.BUFFER | Ops.BIND:
|
||||
pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below)
|
||||
case _:
|
||||
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}")
|
||||
|
||||
with cpu_profile(TracingKey("linearize schedule")):
|
||||
queue: deque[UOp] = deque()
|
||||
|
||||
Reference in New Issue
Block a user