proposal: forced_realize is Ops.CONTIGUOUS (same buffer) (#7533)

* forced_realize is Ops.CONTIGUOUS

* Revert "forced_realize is Ops.CONTIGUOUS"

This reverts commit d8bcd5d301.

* forced_realize is Ops.CONTIGUOUS (same buffer)

* fold contig is lazy
This commit is contained in:
qazal
2024-11-04 18:58:43 +02:00
committed by GitHub
parent 76cc59940d
commit 3163bbc48a

View File

@@ -73,9 +73,9 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
else: ret = UOp(cast(Ops, buf.op), dtype, src)
if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,))
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
if buf.forced_realize: ctx.realizes[ubuf] = ubuf
return ret
# **** AST graph rewrite
@@ -187,7 +187,6 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
to_si = PatternMatcher([
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,x: x),
])
@@ -195,6 +194,7 @@ to_si = PatternMatcher([
lazy = PatternMatcher([
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])