diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ef02ed419d..74a44ca210 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -204,6 +204,30 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
     if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
     _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
 
+def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:Dict[LazyBuffer, Dict[LazyBuffer, None]],\
+    realizes:Dict[LazyBuffer, None], group:Set[LazyBuffer]) -> Set[LazyBuffer]:
+  # create a multi output kernel if the LazyBuffers can cleanly group
+  cache: Set[LazyBuffer] = set()
+  rc_parents, rc_children = deque(group), deque(group)
+  while rc_parents:
+    if (p:=rc_parents.pop()) in cache: continue
+    cache.add(p)
+    # max one reduceop per kernel
+    if p.op in ReduceOps: return set()
+    rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+  # search descendants of the reduceop that can cleanly group
+  cache.clear()
+  realized_descendants: Set[LazyBuffer] = set()
+  while rc_children:
+    if (c:=rc_children.pop()) in cache: continue
+    cache.add(c)
+    if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
+      realized_descendants.clear()
+      break
+    if c in realizes and c not in group: realized_descendants.add(c)
+    rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
+  return group.union(realized_descendants)
+
 def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
   """create a graph for realizing the outputs"""
   # start by just realizing the buffers passed in
@@ -232,29 +256,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
     forced_realize = r in group
     if not forced_realize and len(group) > 1:
-      # create a multi output kernel if the LazyBuffers can cleanly group
-      cache: Set[LazyBuffer] = set()
-      rc_parents, rc_children = deque(group), deque(group)
-      while rc_parents:
-        if (p:=rc_parents.pop()) in cache: continue
-        cache.add(p)
-        # max one reduceop per kernel
-        if p.op in ReduceOps:
-          forced_realize = True
-          break
-        rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
-      # search descendants of the reduceop that can cleanly group
-      cache.clear()
-      realized_descendants: Set[LazyBuffer] = set()
-      while rc_children and not forced_realize:
-        if (c:=rc_children.pop()) in cache: continue
-        cache.add(c)
-        if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
-          realized_descendants.clear()
-          break
-        if c in realizes and c not in group: realized_descendants.add(c)
-        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
-      group.update(realized_descendants)
+      group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
     # can only fuse assign if no other assign_target is used in the kernel
     if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
       parents = deque((r, *group))
@@ -263,7 +265,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
           if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
           continue
         parents.extend(p.srcs)
-    if forced_realize:
+    if forced_realize or not group:
       tr = r
       if can_chase:
         # can chase this down to contiguous children