make fusion deterministic (#5684)

* make fusion deterministic

* not this one yet

* line saving
This commit is contained in:
qazal
2024-07-24 23:37:31 +08:00
committed by GitHub
parent 2ea54176e2
commit 365e7afd4d

View File

@@ -184,38 +184,35 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], cache:Set):
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
if (tr, st) in cache: return
cache.add((tr, st))
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
return group.add(tr)
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
if tr_next.op in ReduceOps: return group.add(r)
if tr_next.op in ReduceOps: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
realizes:Dict[LazyBuffer, None], group:Set[LazyBuffer]) -> Set[LazyBuffer]:
# create a multi output kernel if the LazyBuffers can cleanly group
cache: Set[LazyBuffer] = set()
rc_parents = deque(group)
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
rc_parents, cache = deque(group), set()
while rc_parents:
if (p:=rc_parents.pop()) in cache: continue
cache.add(p)
# max one reduceop per kernel
if p.op in ReduceOps: return set()
if p.op in ReduceOps: return {}
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
# search descendants of the reduceop that can cleanly group
descendants: Set[LazyBuffer] = set()
descendants: Dict[LazyBuffer, None] = {}
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set())
if any(tr in group for tr in descendants): descendants.clear()
return group.union(descendants)
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
"""create a graph for realizing the outputs"""
@@ -238,7 +235,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
for r in allbufs:
if r.op not in ReduceOps or r in realizes: continue
group: Set[LazyBuffer] = set()
group: Dict[LazyBuffer, None] = {}
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)