diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 469c7d771c..8da404faed 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -166,7 +166,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool: return all(_is_padding_okay(x.base, realizes) for x in buf.srcs) def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], - realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]): + realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], seen:List[LazyBuffer]=[]): """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group""" if tr in realizes: # can only fuse contiguous @@ -179,7 +179,9 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa if tr_next.op in ReduceOps: return group.add(r) # can only fuse contiguous if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r) - _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group) + if tr not in seen: + seen = [tr] + seen + _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, seen=seen) def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int], Dict[LazyBuffer, _LBScheduleItem]]: