diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 71c75d9fc1..ba8104215e 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -184,38 +184,35 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
 
 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], cache:Set):
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
   if (tr, st) in cache: return
   cache.add((tr, st))
   if tr in realizes and tr is not r:
     # can only fuse contiguous
     # max one reduceop per kernel
-    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
-    return group.add(tr)
+    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
   for tr_next in children[tr]:
     # max one reduceop per kernel
-    if tr_next.op in ReduceOps: return group.add(r)
+    if tr_next.op in ReduceOps: return group.setdefault(r)
     # can only fuse contiguous
-    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
+    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
     _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
 
 def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
-    realizes:Dict[LazyBuffer, None], group:Set[LazyBuffer]) -> Set[LazyBuffer]:
-  # create a multi output kernel if the LazyBuffers can cleanly group
-  cache: Set[LazyBuffer] = set()
-  rc_parents = deque(group)
+    realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
+  rc_parents, cache = deque(group), set()
   while rc_parents:
     if (p:=rc_parents.pop()) in cache: continue
     cache.add(p)
     # max one reduceop per kernel
-    if p.op in ReduceOps: return set()
+    if p.op in ReduceOps: return {}
     rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
   # search descendants of the reduceop that can cleanly group
-  descendants: Set[LazyBuffer] = set()
+  descendants: Dict[LazyBuffer, None] = {}
   for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set())
-  if any(tr in group for tr in descendants): descendants.clear()
-  return group.union(descendants)
+  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
 
 def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
   """create a graph for realizing the outputs"""
@@ -238,7 +235,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
   for r in allbufs:
     if r.op not in ReduceOps or r in realizes: continue
 
-    group: Set[LazyBuffer] = set()
+    group: Dict[LazyBuffer, None] = {}
     _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
     # max one reduceop per kernel
     can_chase = all(tr not in reduce_for_op for tr in group)
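
Note for reviewers: the change swaps `Set[LazyBuffer]` for `Dict[LazyBuffer, None]`, so the grouping code iterates buffers in insertion order rather than hash order, and `group.setdefault(x)` takes the place of `group.add(x)` (both return `None`, so the `return group.setdefault(...)` early exits behave as before). Below is a minimal standalone sketch of the dict-as-ordered-set pattern, using strings in place of `LazyBuffer` and a local stand-in for tinygrad's `merge_dicts` helper; the names in it are illustrative, not from this PR.

```python
from typing import Dict, Iterable, TypeVar

T, U = TypeVar("T"), TypeVar("U")

def merge_dicts(ds: Iterable[Dict[T, U]]) -> Dict[T, U]:
  # stand-in for tinygrad.helpers.merge_dicts: later dicts extend earlier ones,
  # and each key keeps the position of its first insertion
  out: Dict[T, U] = {}
  for d in ds: out.update(d)
  return out

# a Dict[..., None] used as an insertion-ordered set
group: Dict[str, None] = {}
group.setdefault("reduce_0")       # like set.add; returns None, so `return group.setdefault(...)` works
group.setdefault("elementwise_1")
group.setdefault("reduce_0")       # duplicate add is a no-op, order is preserved

descendants: Dict[str, None] = {"elementwise_2": None}
print(list(merge_dicts([group, descendants])))
# -> ['reduce_0', 'elementwise_1', 'elementwise_2'], the same order on every run
```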