mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 06:35:33 -05:00
scheduling infra for isolated dags (#5679)
* refactor to get_isolated_children * move assign
This commit is contained in:
@@ -204,6 +204,30 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:Dict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
realizes:Dict[LazyBuffer, None], group:Set[LazyBuffer]) -> Set[LazyBuffer]:
|
||||
# create a multi output kernel if the LazyBuffers can cleanly group
|
||||
cache: Set[LazyBuffer] = set()
|
||||
rc_parents, rc_children = deque(group), deque(group)
|
||||
while rc_parents:
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op in ReduceOps: return set()
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
cache.clear()
|
||||
realized_descendants: Set[LazyBuffer] = set()
|
||||
while rc_children:
|
||||
if (c:=rc_children.pop()) in cache: continue
|
||||
cache.add(c)
|
||||
if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
|
||||
realized_descendants.clear()
|
||||
break
|
||||
if c in realizes and c not in group: realized_descendants.add(c)
|
||||
rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
|
||||
return group.union(realized_descendants)
|
||||
|
||||
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
"""create a graph for realizing the outputs"""
|
||||
# start by just realizing the buffers passed in
|
||||
@@ -232,29 +256,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
|
||||
forced_realize = r in group
|
||||
if not forced_realize and len(group) > 1:
|
||||
# create a multi output kernel if the LazyBuffers can cleanly group
|
||||
cache: Set[LazyBuffer] = set()
|
||||
rc_parents, rc_children = deque(group), deque(group)
|
||||
while rc_parents:
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op in ReduceOps:
|
||||
forced_realize = True
|
||||
break
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
cache.clear()
|
||||
realized_descendants: Set[LazyBuffer] = set()
|
||||
while rc_children and not forced_realize:
|
||||
if (c:=rc_children.pop()) in cache: continue
|
||||
cache.add(c)
|
||||
if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
|
||||
realized_descendants.clear()
|
||||
break
|
||||
if c in realizes and c not in group: realized_descendants.add(c)
|
||||
rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
|
||||
group.update(realized_descendants)
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
|
||||
parents = deque((r, *group))
|
||||
@@ -263,7 +265,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
|
||||
continue
|
||||
parents.extend(p.srcs)
|
||||
if forced_realize:
|
||||
if forced_realize or not group:
|
||||
tr = r
|
||||
if can_chase:
|
||||
# can chase this down to contiguous children
|
||||
|
||||
Reference in New Issue
Block a user