From 2a1aa55882b0a3f7810f1e8da61aee5061313994 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 2 Nov 2024 00:00:30 +0200 Subject: [PATCH] add realizes to context [pr] (#7470) * add realizes set * add from fuse --- tinygrad/engine/fuse.py | 3 ++- tinygrad/engine/schedule.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 8c946b95f1..b9567f944f 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -163,5 +163,6 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None: raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}") lazybufs_to_realize[buf.buffer] = buf - output_groups[reduce_for_op.get(buf, buf)].append(ctx.buf_uops[buf.buffer]) + output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer]) + ctx.realizes[ubuf] = ubuf return list(output_groups.values()), lazybufs_to_realize, assign_targets diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a01ddba0c9..9b23772be0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -43,6 +43,7 @@ class ScheduleContext: buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value + realizes: Dict[UOp, UOp] = field(default_factory=dict) # this maps a UOps.BUFFER changing in this schedule to its uop def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r @@ -249,12 +250,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] # get realizes store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx) # split realizes into small graphs - graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize}) + graph_rewrite(big_graph, break_sched, ctx.realizes) assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None} small_graphs: List[Tuple[UOp, ScheduleItemContext]] = [] metadata: List[Set[Metadata]] = [] for stores in store_groups: - sink = UOp.sink(*(realizes[u] for u in stores)) + sink = UOp.sink(*(ctx.realizes[u] for u in stores)) metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))}) small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))