From ace654a7e474bc650eecbe383d6f89155f8d45c9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 15 Dec 2024 02:23:15 +0200 Subject: [PATCH] keep realized passthrough [pr] (#8248) * keep realized passthrough [pr] * more pruning --- tinygrad/engine/schedule.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 79991e574b..910005d326 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -51,23 +51,19 @@ class ScheduleContext: def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r # shapeless op is passthrough - if buf.st is None: return buf + # realized is passthrough + if buf.st is None or buf.base.is_realized: return buf # view is passthrough if buf is not buf.base: cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st) return ret # make things that can't be images not images - dtype = buf.buf_uop.dtype.base if buf.is_realized else buf.dtype + dtype = buf.dtype if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): - assert buf.realized is None, "can't fixup allocated buffer" if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}") dtype = buf.dtype.base - # base is a VIEW of (BUFFER, (optional) op) - if buf.is_realized: - buf_uop = buf.buf_uop - op = None # metaops already have a BUFFER uop - elif is_scheduled(buf): + if is_scheduled(buf): buf_uop = buf.buf_uop op = buf.src[1].replace(src=tuple(to_uop(x, ctx, cache) for x in buf.src[1].src)) # ASSIGN uses the target buffer, otherwise we create a new buffer @@ -75,10 +71,10 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp: src = tuple(to_uop(x, ctx, cache) for x in buf.srcs) buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) op = UOp(buf.op, dtype.base, src, buf.arg) - ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st) # track the underlying tensor uop for this op - if op is not None: ctx.tensor_uops[buf_uop] = [buf] - cache[buf] = ret + ctx.tensor_uops[buf_uop] = [buf] + # (early) bufferize + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st) return ret # **** AST graph rewrite