diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 10afc2c08f..b564771d54 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -40,15 +40,14 @@ class ScheduleItem: @dataclass(frozen=True) class ScheduleContext: - realizes: Dict[Buffer, LazyBuffer] buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) var_vals: Dict[Variable, int] = field(default_factory=dict) -def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: - cache[buf] = ret = to_uop(buf.base, realizes, ctx, cache).view(buf.st) + cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st) return ret dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype # consts are always fused and generated @@ -56,14 +55,11 @@ def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:D if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) return UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0)) # everything else has BUFFER - if (b:=buf.buffer) not in ctx.buf_uops: - ctx.buf_uops[b] = ubuf = UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))) - if b in ctx.realizes: realizes[ubuf] = ubuf - else: ubuf = ctx.buf_uops[b] + ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))) # if the buffer is already realized we just load it if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop())) # everything else needs sources - src = tuple(to_uop(x, realizes, ctx, cache) for x in buf.srcs) + src = tuple(to_uop(x, ctx, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg) @@ -234,12 +230,11 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: store_groups, lazybufs_to_realize, assigns = get_realizes(outs) if len(store_groups) == 0: return [], {} # nothing to schedule - ctx = ScheduleContext(lazybufs_to_realize) - realizes: Dict[UOp, UOp] = {} + ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} - big_graph = UOp.sink(*(to_uop(x.base, realizes, ctx, cache) for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) + big_graph = UOp.sink(*(to_uop(x.base, ctx, cache) for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) # split realizes into small graphs - graph_rewrite(big_graph, break_sched, realizes) + graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize}) assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None} small_graphs: List[Tuple[UOp, ScheduleItemContext]] = [] metadata: List[Set[Metadata]] = []