mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)
keep realized passthrough [pr] (#8248)

* keep realized passthrough [pr]
* more pruning
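What the change does: to_uop previously rebuilt a VIEW(BUFFER) wrapper for bases that were already realized, and carried a dedicated `if buf.is_realized` branch (with `op = None`) through the rest of the function. After this commit, realized bases exit at the top alongside shapeless ops, which is what lets that branch and the `op is None` plumbing below be pruned. Here is a minimal sketch of the pattern, with a toy `Node` and `convert` standing in for tinygrad's `UOp` and `to_uop` (not the real API):

from dataclasses import dataclass

# Toy sketch of "realized is passthrough": a memoized graph walk that
# rewrites lazy nodes but returns already-realized inputs identity-intact
# instead of rebuilding a wrapper around them.
@dataclass(frozen=True)
class Node:
  name: str
  is_realized: bool = False
  srcs: tuple = ()

def convert(n: Node, cache: dict) -> Node:
  if (r := cache.get(n)) is not None: return r
  # realized is passthrough: no new node, no recursion into its sources
  if n.is_realized: return n
  cache[n] = ret = Node(n.name + "'", False, tuple(convert(s, cache) for s in n.srcs))
  return ret

a = Node("a", is_realized=True)
b = Node("b", srcs=(a,))
out = convert(b, {})
assert out.srcs[0] is a  # the realized input is the same object, not a copy
print(out)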
@@ -51,23 +51,19 @@ class ScheduleContext:
 def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
-  # shapeless op is passthrough
-  if buf.st is None: return buf
+  # realized is passthrough
+  if buf.st is None or buf.base.is_realized: return buf
   # view is passthrough
   if buf is not buf.base:
     cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st)
     return ret
   # make things that can't be images not images
-  dtype = buf.buf_uop.dtype.base if buf.is_realized else buf.dtype
+  dtype = buf.dtype
   if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
     assert buf.realized is None, "can't fixup allocated buffer"
     if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
     dtype = buf.dtype.base
-  # base is a VIEW of (BUFFER, (optional) op)
-  if buf.is_realized:
-    buf_uop = buf.buf_uop
-    op = None
   # metaops already have a BUFFER uop
-  elif is_scheduled(buf):
+  if is_scheduled(buf):
     buf_uop = buf.buf_uop
     op = buf.src[1].replace(src=tuple(to_uop(x, ctx, cache) for x in buf.src[1].src))
   # ASSIGN uses the target buffer, otherwise we create a new buffer
@@ -75,10 +71,10 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp:
     src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
     buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
     op = UOp(buf.op, dtype.base, src, buf.arg)
-  ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
   # track the underlying tensor uop for this op
-  if op is not None: ctx.tensor_uops[buf_uop] = [buf]
-  cache[buf] = ret
+  ctx.tensor_uops[buf_uop] = [buf]
+  # (early) bufferize
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
   return ret

 # **** AST graph rewrite
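A note on the image check that survives the hunk unchanged: an ImageDType buffer is forced back to its base dtype when its element count no longer matches the image's, or when no unit-stride axis is a multiple of 4 (image access is 4-wide). Below is a standalone restatement of just that predicate, with made-up shapes standing in for the real ShapeTracker queries:

from math import prod

# The demotion predicate from the diff, restated over plain tuples; the
# unit_stride_axes argument stands in for buf.st.unit_stride_axes().
def must_demote_image(buf_shape: tuple, image_shape: tuple, unit_stride_axes: list) -> bool:
  return prod(buf_shape) != prod(image_shape) or not any(buf_shape[x] % 4 == 0 for x in unit_stride_axes)

print(must_demote_image((8, 8), (16, 4), [1]))  # False: sizes match and axis 1 has 8 % 4 == 0
print(must_demote_image((3, 5), (16, 4), [1]))  # True: 15 elements can't fill a 64-element image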
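On the new single bufferize line: once realized bases return early, `op` can no longer be None at the bottom of the function, so the old `(buf_uop,) if op is None else ...` branch was dead, and every remaining case builds VIEW(BUFFER, op), wrapping the op in CONTIGUOUS when a realize is forced. A toy illustration of that final step, using plain tuples rather than tinygrad's UOp constructor:

# Hedged sketch of the single bufferize step; the tuple encoding is
# illustrative only, not tinygrad's UOp signature.
def bufferize(buf_uop: str, op: tuple, forced_realize: bool) -> tuple:
  # a forced realize wraps the op in CONTIGUOUS so downstream graph
  # rewrites cannot fuse other kernels through this buffer
  src = (buf_uop, ("CONTIGUOUS", op)) if forced_realize else (buf_uop, op)
  return ("VIEW", src)

print(bufferize("BUFFER:0", ("ADD", "x", "y"), forced_realize=False))
print(bufferize("BUFFER:0", ("ADD", "x", "y"), forced_realize=True))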