global metadata try 2 (#7367)

This commit is contained in:
qazal
2024-10-29 14:21:00 +02:00
committed by GitHub
parent 2cfc7b6695
commit d803a9c7c8

View File

@@ -43,12 +43,13 @@ class ScheduleContext:
realizes: Dict[Buffer, LazyBuffer]
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict)
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
var_vals: Dict[Variable, int] = field(default_factory=dict)
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp:
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, outputs, ctx, metadata, cache).view(buf.st)
cache[buf] = ret = to_uop(buf.base, outputs, ctx, cache).view(buf.st)
return ret
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
# consts have VALID + value
@@ -64,7 +65,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada
if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
if b in ctx.realizes and buf not in outputs: return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
# otherwise we fuse it like normal
src = tuple(to_uop(x, outputs, ctx, metadata, cache) for x in buf.srcs)
src = tuple(to_uop(x, outputs, ctx, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
@@ -73,7 +74,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metada
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: metadata[ubuf] = buf.metadata
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
return ret
# **** AST graph rewrite
@@ -239,11 +240,10 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
for stores in store_groups:
outs = [lazybufs_to_realize[b] for b in stores]
cache: Dict[LazyBuffer, UOp] = {}
metadata: Dict[UOp, Metadata] = {}
to_store = tuple(to_uop(out, outs, ctx, metadata, cache) for out in outs)
to_store = tuple(to_uop(out, outs, ctx, cache) for out in outs)
sink = UOp(UOps.SINK, src=tuple(UOp.store(ctx.buf_uops[x.buffer], ShapeTracker.from_shape(x.shape).to_uop(), u) for x,u in zip(outs,to_store)))
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None},
metadata={x:None for x in metadata.values()})
metadata = {mx:None for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))}
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, metadata=metadata)
small_graphs.append((full_ast_rewrite(sink, si_ctx), si_ctx))
# do BFS