From cf3ccb809ff5f365d6e53a8601db5462389a2eca Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 3 May 2024 17:16:34 +0300 Subject: [PATCH] refactor scheduler parents search (#4402) --- tinygrad/engine/schedule.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2d3bc529ec..07d028af87 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -167,16 +167,11 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul forced_realize = True break if len(realized_children) > 1: - for rc in realized_children: - rc_parents = deque(x.base for x in rc.srcs) - while rc_parents: - if (p:=rc_parents.pop()).realized or p.op is LoadOps.CONST: continue - if p is r: continue - # max one reduceop per kernel - if p.op in ReduceOps: - forced_realize = True - break - for x in p.srcs: rc_parents.append(x.base) + rc_parents = deque(realized_children) + while rc_parents and not forced_realize: + # max one reduceop per kernel + if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True + else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r) continue for tr_next in children[tr].keys(): if not tr_next.realized: