diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 05088d8d9e..a2e560c777 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -40,9 +40,9 @@ class ScheduleItem: @dataclass(frozen=True) class ScheduleContext: - buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) - ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) - var_vals: Dict[Variable, int] = field(default_factory=dict) + buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops + ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata + var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r @@ -237,10 +237,12 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {} - store_groups, lazybufs_to_realize, assigns = get_realizes(outs) + # create the big graph ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs)) + # get realizes + store_groups, lazybufs_to_realize, assigns = get_realizes(outs) # split realizes into small graphs graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize}) assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}