diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d01ed73eaa..339bc0103d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -9,38 +9,39 @@ from tinygrad.engine.realize import ExecItem # **** schedule linearizer +# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op) def _unwrap_src(s: UOp) -> UOp: while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0] return s def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: with cpu_profile(TracingKey("toposort sched_sink")): - # construct the KERNEL children graph based on assigns + # build kernel dependency graph: edges from producer kernel to consumer kernels children: dict[UOp, list[UOp]] = {} in_degree: dict[UOp, int] = {} for u in sched_sink.toposort(): if u.op is Ops.RANGE: - in_degree.setdefault(u, 0) + in_degree[u] = 0 continue if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue k = u.src[1] - in_degree.setdefault(k, 0) + in_degree[k] = 0 for s in k.src[0].src if k.op is Ops.END else k.src: - s = _unwrap_src(s) - if s.op is Ops.AFTER: - children.setdefault(s.src[1], []).append(k) - in_degree[k] += 1 - elif s.op in {Ops.MSELECT, Ops.MSTACK}: - for ss in s.src: - if ss.op is Ops.MSELECT: ss = ss.src[0] - if ss.op is not Ops.BUFFER: - assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}" - children.setdefault(ss.src[1], []).append(k) - in_degree[k] += 1 - elif s.op in {Ops.BUFFER, Ops.BIND}: - pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars - else: - raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}") + match (s := _unwrap_src(s)).op: + case Ops.AFTER: + children.setdefault(s.src[1], []).append(k) + in_degree[k] += 1 + case Ops.MSELECT | Ops.MSTACK: + for ss in s.src: + if ss.op is Ops.MSELECT: ss = ss.src[0] + if ss.op is not Ops.BUFFER: + assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}" + children.setdefault(ss.src[1], []).append(k) + in_degree[k] += 1 + case Ops.BUFFER | Ops.BIND: + pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below) + case _: + raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}") with cpu_profile(TracingKey("linearize schedule")): queue: deque[UOp] = deque()