add realizes to context [pr] (#7470)

* add realizes set

* add from fuse
This commit is contained in:
qazal
2024-11-02 00:00:30 +02:00
committed by GitHub
parent e3ea7cc4b4
commit 2a1aa55882
2 changed files with 5 additions and 3 deletions

View File

@@ -163,5 +163,6 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None:
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
lazybufs_to_realize[buf.buffer] = buf
output_groups[reduce_for_op.get(buf, buf)].append(ctx.buf_uops[buf.buffer])
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
ctx.realizes[ubuf] = ubuf
return list(output_groups.values()), lazybufs_to_realize, assign_targets

View File

@@ -43,6 +43,7 @@ class ScheduleContext:
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this maps a UOps.BUFFER changing in this schedule to its uop
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
@@ -249,12 +250,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
# get realizes
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize})
graph_rewrite(big_graph, break_sched, ctx.realizes)
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}
small_graphs: List[Tuple[UOp, ScheduleItemContext]] = []
metadata: List[Set[Metadata]] = []
for stores in store_groups:
sink = UOp.sink(*(realizes[u] for u in stores))
sink = UOp.sink(*(ctx.realizes[u] for u in stores))
metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))