only lookup buf_uops in fuse.py [pr] (#7641)

This commit is contained in:
qazal
2024-11-11 19:14:30 +02:00
committed by GitHub
parent 08b9f055f2
commit 0b66a0d688
2 changed files with 12 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
from collections import defaultdict, deque
from typing import Tuple, List, Dict, DefaultDict
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps
from typing import Set, Tuple, List, Dict, DefaultDict
from tinygrad.device import Buffer
from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.lazy import LazyBuffer
@@ -38,12 +39,14 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]:
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]:
"""search the graph for all the LazyBuffers that need to realize"""
# get all the realizes from big graph
realizes: Dict[LazyBuffer, None] = {}
assigns: Set[UOp] = set()
for r in allbufs:
if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None
if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None
if r.op is Ops.ASSIGN: assigns.add(ubuf)
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
reduce_of_const: List[LazyBuffer] = []
@@ -62,7 +65,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
parents = deque((r, *group))
while parents and not forced_realize:
if (p:=parents.pop().base).is_realized() or p in realizes:
if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False
continue
parents.extend(p.srcs)
if forced_realize or not group:
@@ -92,7 +95,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
top_reduce = reduceop.base.srcs[0].base
if len(children[top_reduce]) == 1:
del realizes[top_reduce]
if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
for r in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
@@ -101,10 +104,10 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
if len(kernel_children) == 0: continue
for tr in group:
del realizes[tr]
if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
for buf in realizes:
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer])
ubuf_realizes[ubuf] = ubuf
return list(output_groups.values())

View File

@@ -275,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
# get realizes
realizes: Dict[UOp, UOp] = {}
graph_rewrite(big_graph, do_realize, realizes)
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx)
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, realizes)
sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]