diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index d8b9e73025..95b29bfa93 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -1,6 +1,7 @@ from collections import defaultdict, deque -from typing import Tuple, List, Dict, DefaultDict -from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps +from typing import Set, Tuple, List, Dict, DefaultDict +from tinygrad.device import Buffer +from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.lazy import LazyBuffer @@ -38,12 +39,14 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], - double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]: + double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]: """search the graph for all the LazyBuffers that need to realize""" # get all the realizes from big graph realizes: Dict[LazyBuffer, None] = {} + assigns: Set[UOp] = set() for r in allbufs: - if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None + if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None + if r.op is Ops.ASSIGN: assigns.add(ubuf) # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} reduce_of_const: List[LazyBuffer] = [] @@ -62,7 +65,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu parents = deque((r, *group)) while parents and not forced_realize: if (p:=parents.pop().base).is_realized() or p in realizes: - if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False + if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False continue parents.extend(p.srcs) if forced_realize or not group: @@ -92,7 +95,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu top_reduce = reduceop.base.srcs[0].base if len(children[top_reduce]) == 1: del realizes[top_reduce] - if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] + if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] for r in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} @@ -101,10 +104,10 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu if len(kernel_children) == 0: continue for tr in group: del realizes[tr] - if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] + if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list) for buf in realizes: - output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer]) + output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer]) ubuf_realizes[ubuf] = ubuf return list(output_groups.values()) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3d3e155e3b..54e6d43bca 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -275,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] # get realizes realizes: Dict[UOp, UOp] = {} graph_rewrite(big_graph, do_realize, realizes) - store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx) + store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops) # split realizes into small graphs graph_rewrite(big_graph, break_sched, realizes) sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]