From ef7ad3f0779fad7fe7a2ea9947c8e8a22dda5ebb Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 5 Feb 2025 09:10:20 +0100 Subject: [PATCH] simpler subbuffer construction + copyin is always base (#8900) * realize copy * cleanup buffer_view * smaller --- tinygrad/engine/schedule.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3c9fc69b4e..97292b9d74 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -63,6 +63,9 @@ sym = symbolic_simple+PatternMatcher([ # support for using a contiguous permuted view instead of the parent view if one exists (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous), + # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK + (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"), + lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None), # remove CONST/BIND/BUFFER from SINK (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) @@ -112,6 +115,7 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src)) # track the buffer uop for the simplified uop buffer_map[buf] = buf_uop + if op.op is Ops.BUFFER_VIEW: buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(op.size, op.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) # (early) bufferize cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret @@ -132,11 +136,6 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) # otherwise safety check pads return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src) -def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None - buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) - return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) - do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)), @@ -144,11 +143,8 @@ do_realize = PatternMatcher([ (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize), # realize before expand or unsafe pad ops (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view), - # realize before COPY or BUFFER_VIEW - (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), - (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), - # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK - (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), + # realize before COPY + (UPat(Ops.COPY, src=(UPat(), UPatScheduled())), realize), ]) def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: