mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up toposort sched_sink [pr] (#14439)
This commit is contained in:
@@ -9,38 +9,39 @@ from tinygrad.engine.realize import ExecItem
|
|||||||
|
|
||||||
# **** schedule linearizer
|
# **** schedule linearizer
|
||||||
|
|
||||||
|
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||||
def _unwrap_src(s: UOp) -> UOp:
|
def _unwrap_src(s: UOp) -> UOp:
|
||||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||||
# construct the KERNEL children graph based on assigns
|
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||||
children: dict[UOp, list[UOp]] = {}
|
children: dict[UOp, list[UOp]] = {}
|
||||||
in_degree: dict[UOp, int] = {}
|
in_degree: dict[UOp, int] = {}
|
||||||
for u in sched_sink.toposort():
|
for u in sched_sink.toposort():
|
||||||
if u.op is Ops.RANGE:
|
if u.op is Ops.RANGE:
|
||||||
in_degree.setdefault(u, 0)
|
in_degree[u] = 0
|
||||||
continue
|
continue
|
||||||
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
|
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
|
||||||
k = u.src[1]
|
k = u.src[1]
|
||||||
in_degree.setdefault(k, 0)
|
in_degree[k] = 0
|
||||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||||
s = _unwrap_src(s)
|
match (s := _unwrap_src(s)).op:
|
||||||
if s.op is Ops.AFTER:
|
case Ops.AFTER:
|
||||||
children.setdefault(s.src[1], []).append(k)
|
children.setdefault(s.src[1], []).append(k)
|
||||||
in_degree[k] += 1
|
in_degree[k] += 1
|
||||||
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
|
case Ops.MSELECT | Ops.MSTACK:
|
||||||
for ss in s.src:
|
for ss in s.src:
|
||||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||||
if ss.op is not Ops.BUFFER:
|
if ss.op is not Ops.BUFFER:
|
||||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||||
children.setdefault(ss.src[1], []).append(k)
|
children.setdefault(ss.src[1], []).append(k)
|
||||||
in_degree[k] += 1
|
in_degree[k] += 1
|
||||||
elif s.op in {Ops.BUFFER, Ops.BIND}:
|
case Ops.BUFFER | Ops.BIND:
|
||||||
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
|
pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below)
|
||||||
else:
|
case _:
|
||||||
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}")
|
||||||
|
|
||||||
with cpu_profile(TracingKey("linearize schedule")):
|
with cpu_profile(TracingKey("linearize schedule")):
|
||||||
queue: deque[UOp] = deque()
|
queue: deque[UOp] = deque()
|
||||||
|
|||||||
Reference in New Issue
Block a user