mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
don't need outputs in fuse.py [pr] (#7639)
This commit is contained in:
@@ -37,7 +37,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
|
||||
def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
|
||||
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
# get all the realizes from big graph
|
||||
@@ -96,7 +96,7 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La
|
||||
|
||||
for r in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
|
||||
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
|
||||
if any(tr.forced_realize for tr in group): continue
|
||||
kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group:
|
||||
|
||||
@@ -275,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
# get realizes
|
||||
realizes: Dict[UOp, UOp] = {}
|
||||
graph_rewrite(big_graph, do_realize, realizes)
|
||||
store_groups = get_realizes(outs, children, allbufs, double_reduces, realizes, ctx)
|
||||
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, realizes)
|
||||
sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
|
||||
|
||||
Reference in New Issue
Block a user