track seen graphs in recursive group (#5301)

* track seen

* maybe never add realized

* ahh it needs to track sts

* delete extra check

* cache typings

* minor cleanup
This commit is contained in:
qazal
2024-07-06 12:39:31 +03:00
committed by GitHub
parent d813617742
commit 11dfb19b20

View File

@@ -142,7 +142,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
realizes[buf.srcs[0].base] = None
if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
for x in buf.srcs:
children[x.base][buf] = None
if x.base.realized is None: children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
@@ -152,20 +152,21 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], cache:Set):
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
if (tr, st) in cache: return
cache.add((tr, st))
if tr in realizes:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
return group.add(tr)
for tr_next in children[tr]:
if tr_next.realized is None:
# max one reduceop per kernel
if tr_next.op in ReduceOps: return group.add(r)
# can only fuse contiguous
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
# max one reduceop per kernel
if tr_next.op in ReduceOps: return group.add(r)
# can only fuse contiguous
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
"""create a graph for realizing the outputs"""
@@ -188,7 +189,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
if r.op not in ReduceOps or r in realizes: continue
group: Set[LazyBuffer] = set()
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs