big graph init conceptual cleanup [pr] (#8090)

* keep Ops.BUFFER naming consistent [pr]

* big graph init conceptual cleanup [pr]

* make everything pass through

* pylint doesn't complain now
Author: qazal
Date: 2024-12-06 20:07:00 +02:00 (committed by GitHub)
Parent: 5184410fc3
Commit: 1ea4dc9565


@@ -51,6 +51,7 @@ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
 def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
+  # view is passthrough
   if buf is not buf.base:
     cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
     return ret
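For reference, the passthrough here is memoized recursion into the base buffer: a view never gets its own node, it just wraps whatever its base converts to. A minimal standalone sketch of that shape (`Buf` and `to_node` are illustrative stand-ins, not tinygrad's real LazyBuffer/to_uop):

# toy sketch of the "view is passthrough" recursion above; Buf and to_node are
# illustrative stand-ins, not tinygrad's real LazyBuffer/to_uop
class Buf:
  def __init__(self, base=None, st="contiguous"):
    self.base = base if base is not None else self  # a view points at a base buffer
    self.st = st

def to_node(buf:Buf, cache:dict) -> tuple:
  if (r:=cache.get(buf)) is not None: return r      # each buffer is converted exactly once
  if buf is not buf.base:                           # view is passthrough:
    cache[buf] = ret = ("VIEW", to_node(buf.base, cache), buf.st)
    return ret
  cache[buf] = ret = ("BUFFER", id(buf))
  return ret

base = Buf()
view = Buf(base=base, st="permuted")
cache:dict = {}
assert to_node(view, cache)[1] is to_node(base, cache)  # view wraps the cached base node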
@@ -64,25 +65,29 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
     # hack the underlying buffer too
     buf.buffer.dtype = dtype
     buf.buffer.options = None
-  if buf.is_realized:
-    ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
-    buffers[ubuf] = buf.buffer
-    op = None
-  elif buf.op is Ops.ASSIGN:
-    target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
-    ctx.assigns.add(ubuf:=target.base.buf_uop)
-    op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
-  else:
-    ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
-    buffers[ubuf] = buf.buffer
-    op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
+  # base is a VIEW of (BUFFER, (optional) op)
+  match buf.is_realized:
+    case True:
+      buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
+      op = None
+    case False:
+      src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
+      match buf.op:
+        # ASSIGN uses the target buffer
+        case Ops.ASSIGN: buf_uop = src[0].base.buf_uop
+        # otherwise we create a new buffer
+        case _: buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
+      op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
+  # keep track of ops outside the big graph
+  buffers[buf_uop] = buf.buffer
   if op is not None:
     buf.buffer.ref(1)
-    ctx.lazybufs[ubuf] = buf
-    ctx.allbufs[ubuf] = ret
+    ctx.lazybufs[buf_uop] = buf
+    ctx.allbufs[buf_uop] = ret
+    if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
     for x in op.src:
-      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
+      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
   return ret

 # **** AST graph rewrite
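The hunk above is the conceptual cleanup: every base becomes a VIEW whose sources are a BUFFER plus an optional op, with ASSIGN reusing its target's BUFFER instead of allocating a new one. A hedged toy model of that dispatch, to make the invariant concrete (`Node`, `new_buffer`, and `to_view` are illustrative, not tinygrad's UOp API):

# toy model of the invariant "base is a VIEW of (BUFFER, (optional) op)";
# Node and new_buffer are illustrative, not tinygrad's UOp API
import itertools
from dataclasses import dataclass
from typing import Tuple, Optional, Any

_buf_ids = itertools.count()

@dataclass(frozen=True)
class Node:
  op: str
  src: Tuple["Node", ...] = ()
  arg: Any = None

def new_buffer() -> Node: return Node("BUFFER", arg=next(_buf_ids))

def to_view(realized:bool, op_name:Optional[str]=None, src:Tuple[Node, ...]=()) -> Node:
  match realized:
    case True:
      buf_uop, op = new_buffer(), None          # realized: just a BUFFER, no op
    case False:
      match op_name:
        case "ASSIGN": buf_uop = src[0].src[0]  # ASSIGN reuses the target's BUFFER
        case _:        buf_uop = new_buffer()   # otherwise we create a new buffer
      op = Node(op_name, src)
  return Node("VIEW", (buf_uop,) if op is None else (buf_uop, op))

a = to_view(True)                       # realized input
b = to_view(False, "ADD", (a, a))       # a compute op gets a fresh buffer
c = to_view(False, "ASSIGN", (a, b))    # ASSIGN writes into a's buffer
assert c.src[0] is a.src[0] and b.src[0] is not a.src[0]

One consequence of this shape, visible in the diff, is that buffer bookkeeping (`buffers[buf_uop] = buf.buffer`, `ctx.assigns`) can happen uniformly after the match instead of being duplicated per branch.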
@@ -178,7 +183,9 @@ to_si = PatternMatcher([
 # ** fusion

 lazy = PatternMatcher([
   # gather the metadata for this kernel
   (UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
+  # don't need contiguous anymore
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
 ])
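The new rule erases a CONTIGUOUS node by rewriting it to its single source: once scheduling has decided what to realize, the marker carries no information. A minimal sketch of that rewrite idea in plain Python (a toy `rewrite` over tuple graphs, not tinygrad's PatternMatcher):

# toy graph rewrite that erases CONTIGUOUS by returning its source,
# mirroring (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x)
def rewrite(node):
  # node is an ("OP", *children) tuple; rewrite bottom-up
  op, *src = node
  src = [rewrite(s) for s in src]
  if op == "CONTIGUOUS" and len(src) == 1: return src[0]  # don't need contiguous anymore
  return (op, *src)

g = ("ADD", ("CONTIGUOUS", ("LOAD",)), ("LOAD",))
assert rewrite(g) == ("ADD", ("LOAD",), ("LOAD",))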
@@ -375,6 +382,8 @@ do_realize = PatternMatcher([
   (UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before COPY or BUFFER_VIEW
   (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # ASSIGN only needs the buffer
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda ctx,dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
 ])

 # ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
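The ASSIGN rule narrows the destination from a full VIEW to just its underlying BUFFER, matching the `buf_uop = src[0].base.buf_uop` choice in `to_uop`. A hedged sketch of the same source swap on toy tuple nodes (illustrative only, not the real `UOp.replace`):

# toy version of the ASSIGN rewrite: replace the dest VIEW with its BUFFER,
# mirroring x.replace(src=(dest.base.buf_uop, src)); all names illustrative
def rewrite_assign(assign):
  op, dest, src = assign            # assign = ("ASSIGN", dest_view, src)
  dop, buf, *rest = dest            # dest = ("VIEW", buffer, ...optional op)
  assert op == "ASSIGN" and dop == "VIEW"
  return ("ASSIGN", buf, src)       # ASSIGN only needs the buffer

dest = ("VIEW", ("BUFFER", 0), ("ADD",))
assert rewrite_assign(("ASSIGN", dest, ("LOAD",))) == ("ASSIGN", ("BUFFER", 0), ("LOAD",))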