mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
[experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)
This commit is contained in:
@@ -166,7 +166,7 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], seen:List[LazyBuffer]=[]):
|
||||
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
|
||||
if tr in realizes:
|
||||
# can only fuse contiguous
|
||||
@@ -179,7 +179,9 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
||||
if tr_next.op in ReduceOps: return group.add(r)
|
||||
# can only fuse contiguous
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
|
||||
if tr not in seen:
|
||||
seen = [tr] + seen
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, seen=seen)
|
||||
|
||||
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
|
||||
Dict[LazyBuffer, _LBScheduleItem]]:
|
||||
|
||||
Reference in New Issue
Block a user