mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
global metadata try 2 (#7367)
This commit is contained in:
@@ -43,12 +43,13 @@ class ScheduleContext:
|
||||
realizes: Dict[Buffer, LazyBuffer]
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
|
||||
uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict)
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict)
|
||||
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, ctx, metadata, cache).view(buf.st)
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, ctx, cache).view(buf.st)
|
||||
return ret
|
||||
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
|
||||
# consts have VALID + value
|
||||
@@ -64,7 +65,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada
|
||||
if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
if b in ctx.realizes and buf not in outputs: return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
# otherwise we fuse it like normal
|
||||
src = tuple(to_uop(x, outputs, ctx, metadata, cache) for x in buf.srcs)
|
||||
src = tuple(to_uop(x, outputs, ctx, cache) for x in buf.srcs)
|
||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
@@ -73,7 +74,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada
|
||||
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
|
||||
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
|
||||
cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
|
||||
if buf.metadata is not None: metadata[ubuf] = buf.metadata
|
||||
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
|
||||
return ret
|
||||
|
||||
# **** AST graph rewrite
|
||||
@@ -239,11 +240,10 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
for stores in store_groups:
|
||||
outs = [lazybufs_to_realize[b] for b in stores]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
metadata: Dict[UOp, Metadata] = {}
|
||||
to_store = tuple(to_uop(out, outs, ctx, metadata, cache) for out in outs)
|
||||
to_store = tuple(to_uop(out, outs, ctx, cache) for out in outs)
|
||||
sink = UOp(UOps.SINK, src=tuple(UOp.store(ctx.buf_uops[x.buffer], ShapeTracker.from_shape(x.shape).to_uop(), u) for x,u in zip(outs,to_store)))
|
||||
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None},
|
||||
metadata={x:None for x in metadata.values()})
|
||||
metadata = {mx:None for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))}
|
||||
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, metadata=metadata)
|
||||
small_graphs.append((full_ast_rewrite(sink, si_ctx), si_ctx))
|
||||
|
||||
# do BFS
|
||||
|
||||
Reference in New Issue
Block a user