diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index b4be3e8bbe..f2b9f8f29a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -51,6 +51,7 @@ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
 
 def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
+  # view is passthrough
   if buf is not buf.base:
     cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
     return ret
@@ -64,25 +65,29 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
     # hack the underlying buffer too
     buf.buffer.dtype = dtype
     buf.buffer.options = None
-  if buf.is_realized:
-    ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
-    buffers[ubuf] = buf.buffer
-    op = None
-  elif buf.op is Ops.ASSIGN:
-    target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
-    ctx.assigns.add(ubuf:=target.base.buf_uop)
-    op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
-  else:
-    ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
-    buffers[ubuf] = buf.buffer
-    op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
+  # base is a VIEW of (BUFFER, (optional) op)
+  match buf.is_realized:
+    case True:
+      buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
+      op = None
+    case False:
+      src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
+      match buf.op:
+        # ASSIGN uses the target buffer
+        case Ops.ASSIGN: buf_uop = src[0].base.buf_uop
+        # otherwise we create a new buffer
+        case _: buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
+      op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
+  # keep track of ops outside the big graph
+  buffers[buf_uop] = buf.buffer
   if op is not None:
     buf.buffer.ref(1)
-    ctx.lazybufs[ubuf] = buf
-    ctx.allbufs[ubuf] = ret
+    ctx.lazybufs[buf_uop] = buf
+    ctx.allbufs[buf_uop] = ret
+    if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
     for x in op.src:
-      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
+      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
   return ret
 
 # **** AST graph rewrite
@@ -178,7 +183,9 @@ to_si = PatternMatcher([
 # ** fusion
 
 lazy = PatternMatcher([
+  # gather the metadata for this kernel
   (UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
+  # don't need contiguous anymore
   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
 ])
 
@@ -375,6 +382,8 @@ do_realize = PatternMatcher([
   (UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before COPY or BUFFER_VIEW
   (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # ASSIGN only needs the buffer
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda ctx,dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
 ])
 
 # ** this breaks down realized ops into STOREs and rewrites the ops to LOADs