mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
simpler subbuffer construction + copyin is always base (#8900)
* realize copy
* cleanup buffer_view
* smaller
This commit is contained in:
@@ -63,6 +63,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
|
||||
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
|
||||
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
|
||||
# remove CONST/BIND/BUFFER from SINK
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
|
||||
@@ -112,6 +115,7 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
|
||||
# track the buffer uop for the simplified uop
|
||||
buffer_map[buf] = buf_uop
|
||||
if op.op is Ops.BUFFER_VIEW: buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(op.size, op.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
|
||||
# (early) bufferize
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
|
||||
return ret
|
||||
@@ -132,11 +136,6 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
|
||||
|
||||
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  """Register `b` as a view into `x`'s buffer and rewrite `root` as a BUFFER_VIEW.

  Returns None (no rewrite) when `b.device` is a tuple or does not start with "DISK".
  Otherwise stores the sub-buffer in `buffers[b]` and returns `base` with its
  sources replaced by `(b, root with op=Ops.BUFFER_VIEW)`.
  """
  device = b.device
  # only a single (non-tuple) device whose name starts with "DISK" is handled here
  if isinstance(device, tuple): return None
  if not device.startswith("DISK"): return None
  parent_buffer = x.buf_uop.buffer
  # first view's offset scaled by the element size (presumably a byte offset -- confirm against Buffer.view)
  view_offset = unwrap(x.st).views[0].offset*x.dtype.itemsize
  buffers[b] = parent_buffer.view(b.size, b.dtype, view_offset)
  as_buffer_view = root.replace(op=Ops.BUFFER_VIEW)
  return base.replace(src=(b, as_buffer_view))
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize SINK parents
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
|
||||
@@ -144,11 +143,8 @@ do_realize = PatternMatcher([
|
||||
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
||||
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
|
||||
# realize before COPY
|
||||
(UPat(Ops.COPY, src=(UPat(), UPatScheduled())), realize),
|
||||
])
|
||||
|
||||
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
|
||||
Reference in New Issue
Block a user