From 3163bbc48aa00ca3cff80ee1c5a933396c6d260b Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 4 Nov 2024 18:58:43 +0200 Subject: [PATCH] proposal: forced_realize is Ops.CONTIGUOUS (same buffer) (#7533) * forced_realize is Ops.CONTIGUOUS * Revert "forced_realize is Ops.CONTIGUOUS" This reverts commit d8bcd5d301e44c7d775b6ec1aaf8a7eefb0611ec. * forced_realize is Ops.CONTIGUOUS (same buffer) * fold contig is lazy --- tinygrad/engine/schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 0de75c30df..04c4ea50c1 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -73,9 +73,9 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg) elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg) else: ret = UOp(cast(Ops, buf.op), dtype, src) + if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,)) cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata - if buf.forced_realize: ctx.realizes[ubuf] = ubuf return ret # **** AST graph rewrite @@ -187,7 +187,6 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: to_si = PatternMatcher([ (UPat(Ops.VIEW, name="x"), _append_st_vars), (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), - (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), (UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,x: x), ]) @@ -195,6 +194,7 @@ to_si = PatternMatcher([ lazy = PatternMatcher([ (UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v), + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), ]) multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])