From d803a9c7c8327e602eafeb4195b65ef3397b520e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:21:00 +0200 Subject: [PATCH] global metadata try 2 (#7367) --- tinygrad/engine/schedule.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 969bfbc0e7..9901bbcafc 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -43,12 +43,13 @@ class ScheduleContext: realizes: Dict[Buffer, LazyBuffer] buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict) + ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) var_vals: Dict[Variable, int] = field(default_factory=dict) -def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: - cache[buf] = ret = to_uop(buf.base, outputs, ctx, metadata, cache).view(buf.st) + cache[buf] = ret = to_uop(buf.base, outputs, ctx, cache).view(buf.st) return ret dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype # consts have VALID + value @@ -64,7 +65,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop())) if b in ctx.realizes and buf not in outputs: return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop())) # otherwise we fuse it like normal - src = tuple(to_uop(x, outputs, ctx, metadata, cache) for x in buf.srcs) + src = tuple(to_uop(x, outputs, ctx, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg) @@ -73,7 +74,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src) else: ret = UOp(UOps.ALU, dtype, src, buf.op) cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) - if buf.metadata is not None: metadata[ubuf] = buf.metadata + if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata return ret # **** AST graph rewrite @@ -239,11 +240,10 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] for stores in store_groups: outs = [lazybufs_to_realize[b] for b in stores] cache: Dict[LazyBuffer, UOp] = {} - metadata: Dict[UOp, Metadata] = {} - to_store = tuple(to_uop(out, outs, ctx, metadata, cache) for out in outs) + to_store = tuple(to_uop(out, outs, ctx, cache) for out in outs) sink = UOp(UOps.SINK, src=tuple(UOp.store(ctx.buf_uops[x.buffer], ShapeTracker.from_shape(x.shape).to_uop(), u) for x,u in zip(outs,to_store))) - si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, - metadata={x:None for x in metadata.values()}) + metadata = {mx:None for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))} + si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, metadata=metadata) small_graphs.append((full_ast_rewrite(sink, si_ctx), si_ctx)) # do BFS