mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-03 19:25:06 -05:00
add realizes to context [pr] (#7470)
* add realizes set * add from fuse
This commit is contained in:
@@ -163,5 +163,6 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
|
||||
if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None:
|
||||
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
|
||||
lazybufs_to_realize[buf.buffer] = buf
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ctx.buf_uops[buf.buffer])
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
|
||||
ctx.realizes[ubuf] = ubuf
|
||||
return list(output_groups.values()), lazybufs_to_realize, assign_targets
|
||||
|
||||
@@ -43,6 +43,7 @@ class ScheduleContext:
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
|
||||
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this maps a UOps.BUFFER changing in this schedule to its uop
|
||||
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
@@ -249,12 +250,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
# get realizes
|
||||
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize})
|
||||
graph_rewrite(big_graph, break_sched, ctx.realizes)
|
||||
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}
|
||||
small_graphs: List[Tuple[UOp, ScheduleItemContext]] = []
|
||||
metadata: List[Set[Metadata]] = []
|
||||
for stores in store_groups:
|
||||
sink = UOp.sink(*(realizes[u] for u in stores))
|
||||
sink = UOp.sink(*(ctx.realizes[u] for u in stores))
|
||||
metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
|
||||
small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user