From 5e2e5b2cdc6bf393ee197b5bcb829a20577d5db6 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 30 Oct 2024 07:58:09 +0200 Subject: [PATCH] finally big graph (#7293) * real big graph * extra lines --- tinygrad/engine/schedule.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 05f3db7ada..575bb625a7 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -46,13 +46,13 @@ class ScheduleContext: ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) var_vals: Dict[Variable, int] = field(default_factory=dict) -def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: - cache[buf] = ret = to_uop(buf.base, outputs, ctx, cache).view(buf.st) + cache[buf] = ret = to_uop(buf.base, realizes, ctx, cache).view(buf.st) return ret dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype - # consts have VALID + value + # consts are always fused and generated if buf.op is MetaOps.CONST: if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) return UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0)) @@ -60,12 +60,12 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache: if (b:=buf.buffer) not in ctx.buf_uops: ctx.buf_uops[b] = ubuf = UOp(UOps.BUFFER, buf.buffer.dtype.ptr(), (), (len(ctx.buf_uops), (buf.buffer.device, buf.buffer.size, buf.buffer.dtype))) ctx.uop_bufs[ubuf] = b + if b in ctx.realizes: realizes[ubuf] = ubuf else: ubuf = ctx.buf_uops[b] - # if it's not fused it's a LOAD + # if the buffer is already realized we just load it if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop())) - if b in ctx.realizes and buf not in outputs: return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop())) - # otherwise we fuse it like normal - src = tuple(to_uop(x, outputs, ctx, cache) for x in buf.srcs) + # everything else needs sources + src = tuple(to_uop(x, realizes, ctx, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg) @@ -233,17 +233,25 @@ if getenv("RUN_PROCESS_REPLAY"): # **** Schedule creation and BFS toposort +def _add_realize(realizes:Dict[UOp, UOp], b:UOp, store:UOp, load:UOp) -> Optional[UOp]: + if b not in realizes: return None + realizes[b] = store + return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop())) +break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat(), name="store"), name="load"), _add_realize),]) + @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: store_groups, lazybufs_to_realize, assigns = get_realizes(outs) + if len(store_groups) == 0: return [], {} # nothing to schedule ctx = ScheduleContext(lazybufs_to_realize) + realizes: Dict[UOp, UOp] = {} + cache: Dict[LazyBuffer, UOp] = {} + big_graph = UOp.sink(*(to_uop(x.base, realizes, ctx, cache) for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) # split realizes into small graphs - small_graphs: List[Tuple[UOp, ScheduleItemContext]] = [] - for stores in store_groups: - outs = [lazybufs_to_realize[b] for b in stores] - cache: Dict[LazyBuffer, UOp] = {} - small_graphs.append(full_ast_rewrite(UOp.sink(*(to_uop(out, outs, ctx, cache).src[2] for out in outs)), - ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, ctx.ubuf_metadata)) + graph_rewrite(big_graph, break_sched, realizes) + assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None} + small_graphs = [full_ast_rewrite(UOp.sink(*(realizes[ctx.buf_uops[b]] for b in stores)), + ctx.var_vals, assigned, ctx.ubuf_metadata) for stores in store_groups] # do BFS prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=ctx.uop_bufs[u]).size != 0),