[experiment] Topological sort when doing _recursive_group (I don't know if this is good, but at least it works.)

This commit is contained in:
hikettei
2024-07-05 21:26:18 +09:00
parent 29bf027f87
commit 0eee08b87c

View File

@@ -166,7 +166,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], seen:List[LazyBuffer]=[]):
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
if tr in realizes:
# can only fuse contiguous
@@ -179,7 +179,9 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
if tr_next.op in ReduceOps: return group.add(r)
# can only fuse contiguous
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
if tr not in seen:
seen = [tr] + seen
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, seen=seen)
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
Dict[LazyBuffer, _LBScheduleItem]]: