simpler subbuffer construction + copyin is always base (#8900)

* realize copy

* cleanup buffer_view

* smaller
This commit is contained in:
qazal
2025-02-05 09:10:20 +01:00
committed by GitHub
parent 6f0cc2e9c5
commit ef7ad3f077

View File

@@ -63,6 +63,9 @@ sym = symbolic_simple+PatternMatcher([
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
@@ -112,6 +115,7 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
# track the buffer uop for the simplified uop
buffer_map[buf] = buf_uop
if op.op is Ops.BUFFER_VIEW: buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(op.size, op.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
@@ -132,11 +136,6 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
@@ -144,11 +143,8 @@ do_realize = PatternMatcher([
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
# realize before COPY or BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
# realize before COPY
(UPat(Ops.COPY, src=(UPat(), UPatScheduled())), realize),
])
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: