mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 16:27:59 -05:00
big graph init conceptual cleanup [pr] (#8090)
* keep Ops.BUFFER naming consistent [pr] * big graph init conceptual cleanup [pr] * make everything pass through * pylint doesn't complain now
This commit is contained in:
@@ -51,6 +51,7 @@ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
|
||||
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
# view is passthrough
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
|
||||
return ret
|
||||
@@ -64,25 +65,29 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
|
||||
# hack the underlying buffer too
|
||||
buf.buffer.dtype = dtype
|
||||
buf.buffer.options = None
|
||||
if buf.is_realized:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = None
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
|
||||
ctx.assigns.add(ubuf:=target.base.buf_uop)
|
||||
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
|
||||
else:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
# base is a VIEW of (BUFFER, (optional) op)
|
||||
match buf.is_realized:
|
||||
case True:
|
||||
buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = None
|
||||
case False:
|
||||
src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
|
||||
match buf.op:
|
||||
# ASSIGN uses the target buffer
|
||||
case Ops.ASSIGN: buf_uop = src[0].base.buf_uop
|
||||
# otherwise we create a new buffer
|
||||
case _: buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
# keep track of ops outside the big graph
|
||||
buffers[buf_uop] = buf.buffer
|
||||
if op is not None:
|
||||
buf.buffer.ref(1)
|
||||
ctx.lazybufs[ubuf] = buf
|
||||
ctx.allbufs[ubuf] = ret
|
||||
ctx.lazybufs[buf_uop] = buf
|
||||
ctx.allbufs[buf_uop] = ret
|
||||
if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
|
||||
for x in op.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
|
||||
return ret
|
||||
|
||||
# **** AST graph rewrite
|
||||
@@ -178,7 +183,9 @@ to_si = PatternMatcher([
|
||||
# ** fusion
|
||||
|
||||
lazy = PatternMatcher([
|
||||
# gather the metadata for this kernel
|
||||
(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
|
||||
# don't need contiguous anymore
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
])
|
||||
|
||||
@@ -375,6 +382,8 @@ do_realize = PatternMatcher([
|
||||
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
# ASSIGN only needs the buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda ctx,dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
|
||||
])
|
||||
|
||||
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
|
||||
|
||||
Reference in New Issue
Block a user