keep realized passthrough [pr] (#8248)

* keep realized passthrough [pr]

* more pruning
This commit is contained in:
qazal
2024-12-15 02:23:15 +02:00
committed by GitHub
parent d78e75f710
commit ace654a7e4

View File

@@ -51,23 +51,19 @@ class ScheduleContext:
def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# shapeless op is passthrough
if buf.st is None: return buf
# realized is passthrough
if buf.st is None or buf.base.is_realized: return buf
# view is passthrough
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st)
return ret
# make things that can't be images not images
dtype = buf.buf_uop.dtype.base if buf.is_realized else buf.dtype
dtype = buf.dtype
if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
assert buf.realized is None, "can't fixup allocated buffer"
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
dtype = buf.dtype.base
# base is a VIEW of (BUFFER, (optional) op)
if buf.is_realized:
buf_uop = buf.buf_uop
op = None
# metaops already have a BUFFER uop
elif is_scheduled(buf):
if is_scheduled(buf):
buf_uop = buf.buf_uop
op = buf.src[1].replace(src=tuple(to_uop(x, ctx, cache) for x in buf.src[1].src))
# ASSIGN uses the target buffer, otherwise we create a new buffer
@@ -75,10 +71,10 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = UOp(buf.op, dtype.base, src, buf.arg)
ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
# track the underlying tensor uop for this op
if op is not None: ctx.tensor_uops[buf_uop] = [buf]
cache[buf] = ret
ctx.tensor_uops[buf_uop] = [buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
return ret
# **** AST graph rewrite