From 6dcdff3bfda7c724617872b26de9447109c5f1f4 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 24 Jul 2024 19:07:10 +0800 Subject: [PATCH] share fusion behavior for r3 kernels (#5680) * use groups * this is the next one * should check the whole graph --- tinygrad/engine/schedule.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 74a44ca210..b67053b070 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -192,7 +192,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group""" if (tr, st) in cache: return cache.add((tr, st)) - if tr in realizes: + if tr in realizes and tr is not r: # can only fuse contiguous # max one reduceop per kernel if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r) @@ -204,11 +204,11 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r) _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache) -def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:Dict[LazyBuffer, Dict[LazyBuffer, None]],\ +def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\ realizes:Dict[LazyBuffer, None], group:Set[LazyBuffer]) -> Set[LazyBuffer]: # create a multi output kernel if the LazyBuffers can cleanly group cache: Set[LazyBuffer] = set() - rc_parents, rc_children = deque(group), deque(group) + rc_parents = deque(group) while rc_parents: if (p:=rc_parents.pop()) in cache: continue cache.add(p) @@ -216,17 +216,10 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff if p.op in ReduceOps: return set() rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r) # search descendants of the reduceop that can cleanly group - cache.clear() - realized_descendants: Set[LazyBuffer] = set() - while rc_children: - if (c:=rc_children.pop()) in cache: continue - cache.add(c) - if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op: - realized_descendants.clear() - break - if c in realizes and c not in group: realized_descendants.add(c) - rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device) - return group.union(realized_descendants) + descendants: Set[LazyBuffer] = set() + for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set()) + if any(tr in group for tr in descendants): descendants.clear() + return group.union(descendants) def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]): """create a graph for realizing the outputs"""