mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
map groupable children (#5710)
* map groupable children * remove setitem
This commit is contained in:
@@ -186,20 +186,21 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, bool], cache:Set):
|
||||
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
|
||||
if (tr, st) in cache: return
|
||||
cache.add((tr, st))
|
||||
if tr in realizes and tr is not r:
|
||||
# can only fuse contiguous
|
||||
# max one reduceop per kernel
|
||||
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
|
||||
return group.setdefault(tr)
|
||||
group[tr] = st.contiguous and st.size == r.st.size and tr not in reduce_for_op
|
||||
return
|
||||
for tr_next in children[tr]:
|
||||
# max one reduceop per kernel
|
||||
if tr_next.op in ReduceOps: return group.setdefault(r)
|
||||
# can only fuse contiguous
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
|
||||
if tr_next.op in ReduceOps or len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1:
|
||||
group[tr_next] = False
|
||||
return
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
@@ -212,9 +213,10 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
if p.op in ReduceOps: return {}
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
descendants: Dict[LazyBuffer, None] = {}
|
||||
descendants: Dict[LazyBuffer, bool] = {}
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set())
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
descendants_to_group = {tr:None for tr,can_group in descendants.items() if can_group}
|
||||
return merge_dicts([group, descendants_to_group if len(descendants_to_group) == len(descendants) else {}])
|
||||
|
||||
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
"""create a graph for realizing the outputs"""
|
||||
@@ -237,13 +239,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
for r in allbufs:
|
||||
if r.op not in ReduceOps or r in realizes: continue
|
||||
|
||||
group: Dict[LazyBuffer, None] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
|
||||
reduceop_children: Dict[LazyBuffer, bool] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, reduceop_children, cache=set())
|
||||
# max one reduceop per kernel
|
||||
can_chase = all(tr not in reduce_for_op for tr in group)
|
||||
can_chase = all(tr not in reduce_for_op for tr in reduceop_children)
|
||||
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
|
||||
forced_realize = r in group
|
||||
if not forced_realize and len(group) > 1:
|
||||
forced_realize = any(not can_group for can_group in reduceop_children.values())
|
||||
if len(group:={tr:None for tr,can_group in reduceop_children.items() if can_group}) > 1 and not forced_realize:
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
|
||||
|
||||
Reference in New Issue
Block a user