diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4ef2e17936..60573e7c68 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -68,25 +68,26 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[U if buf.op is Ops.CONST: if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0) - # everything else has BUFFER - ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp.new_buffer(b.device, b.size, b.dtype, num=len(ctx.buf_uops))) - # if the buffer is already realized we just load it - if buf.is_realized(): return UOp(Ops.VIEW, dtype, (ubuf,), buf.st) - # everything else needs sources - src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs) - if buf.op is Ops.ASSIGN: - ctx.assigns.add(ubuf) - ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg) - else: ret = UOp(cast(Ops, buf.op), dtype, src, None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg) - if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,)) - cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf, ret), buf.st) - if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata - ctx.lazybufs[b] = buf - # things for fuse.py - allbufs[ubuf] = ret - if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None - for x in src: - if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None + # everything else is a VIEW of BUFFER (with an optional op) + if buf.is_realized(): + ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops)) + op = None + elif buf.op is Ops.ASSIGN: + target, new_val = [to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs] + ctx.assigns.add(ubuf:=target.buf_uop) + op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg) + else: + ctx.buf_uops[buf.buffer] = ubuf = UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops)) + op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs), + None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg) + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st) + if op is not None: + if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata + ctx.lazybufs[buf.buffer] = buf + for x in op.src: + if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None + allbufs[ubuf] = ret + if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None return ret # **** AST graph rewrite diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e6a8deb5e6..73c4d06adc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -336,6 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False)) def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) + def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,)) @property def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)