From fbd7d16e9e61a5fde31742022cfef34f2369b18e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 6 Nov 2024 18:24:07 +0200 Subject: [PATCH] create realizes later [pr] (#7571) --- tinygrad/engine/fuse.py | 10 +++++----- tinygrad/engine/schedule.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index f48634f69a..3a701d1caa 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -38,12 +38,12 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], - double_reduces:Dict[LazyBuffer, None], ctx) -> List[List[UOp]]: + double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]: """search the graph for all the LazyBuffers that need to realize""" # get all the realizes from big graph realizes: Dict[LazyBuffer, None] = {} for r in allbufs: - if ctx.buf_uops[r.buffer] in ctx.realizes: realizes[r] = None + if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} reduce_of_const: List[LazyBuffer] = [] @@ -92,7 +92,7 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La top_reduce = reduceop.base.srcs[0].base if len(children[top_reduce]) == 1: del realizes[top_reduce] - if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ctx.realizes: del ctx.realizes[ubuf] + if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] for r in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} @@ -101,10 +101,10 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La if len(kernel_children) == 0: continue for tr in group: del realizes[tr] - if (ubuf:=ctx.buf_uops[tr.buffer]) in ctx.realizes: del ctx.realizes[ubuf] + if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list) for buf in realizes: output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer]) - ctx.realizes[ubuf] = ubuf + ubuf_realizes[ubuf] = ubuf return list(output_groups.values()) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 56fa294128..42891bdfbd 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -42,7 +42,6 @@ class ScheduleContext: buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value - realizes: Dict[UOp, UOp] = field(default_factory=dict) # this maps a UOps.BUFFER changing in this schedule to its uop assigns: Set[Buffer] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized @@ -280,11 +279,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] double_reduces: Dict[LazyBuffer, None] = {} big_graph = UOp.sink(*(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in outs)) # get realizes - graph_rewrite(big_graph, do_realize, ctx.realizes) - store_groups = get_realizes(outs, children, allbufs, double_reduces, ctx) + realizes: Dict[UOp, UOp] = {} + graph_rewrite(big_graph, do_realize, realizes) + store_groups = get_realizes(outs, children, allbufs, double_reduces, realizes, ctx) # split realizes into small graphs - graph_rewrite(big_graph, break_sched, ctx.realizes) - sinks = [UOp.sink(*(ctx.realizes[u] for u in stores)) for stores in store_groups] + graph_rewrite(big_graph, break_sched, realizes) + sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups] # preschedule all realizes bufs = list(ctx.buf_uops) prescheduled: List[ScheduleItem] = []