mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
only lookup buf_uops in fuse.py [pr] (#7641)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from collections import defaultdict, deque
|
||||
from typing import Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps
|
||||
from typing import Set, Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
@@ -38,12 +39,14 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
|
||||
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]:
|
||||
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
# get all the realizes from big graph
|
||||
realizes: Dict[LazyBuffer, None] = {}
|
||||
assigns: Set[UOp] = set()
|
||||
for r in allbufs:
|
||||
if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None
|
||||
if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None
|
||||
if r.op is Ops.ASSIGN: assigns.add(ubuf)
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
reduce_of_const: List[LazyBuffer] = []
|
||||
@@ -62,7 +65,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
parents = deque((r, *group))
|
||||
while parents and not forced_realize:
|
||||
if (p:=parents.pop().base).is_realized() or p in realizes:
|
||||
if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
|
||||
if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False
|
||||
continue
|
||||
parents.extend(p.srcs)
|
||||
if forced_realize or not group:
|
||||
@@ -92,7 +95,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
top_reduce = reduceop.base.srcs[0].base
|
||||
if len(children[top_reduce]) == 1:
|
||||
del realizes[top_reduce]
|
||||
if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
for r in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
|
||||
@@ -101,10 +104,10 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group:
|
||||
del realizes[tr]
|
||||
if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
|
||||
for buf in realizes:
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer])
|
||||
ubuf_realizes[ubuf] = ubuf
|
||||
return list(output_groups.values())
|
||||
|
||||
@@ -275,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
# get realizes
|
||||
realizes: Dict[UOp, UOp] = {}
|
||||
graph_rewrite(big_graph, do_realize, realizes)
|
||||
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx)
|
||||
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, realizes)
|
||||
sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
|
||||
|
||||
Reference in New Issue
Block a user