mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
feat: late to_bufferview (#12271)
This commit is contained in:
@@ -53,12 +53,6 @@ earliest_rewrites = double_reshape+PatternMatcher([
|
||||
# copy only to different device
|
||||
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
|
||||
|
||||
# handle disk
|
||||
# TODO: this doesn't need to use st.views
|
||||
(UPat.var("x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t"),
|
||||
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) \
|
||||
and x.device.startswith("DISK") else None),
|
||||
|
||||
# contiguous/buffer/copy/assign is already contiguous
|
||||
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
|
||||
])
|
||||
@@ -407,6 +401,25 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
|
||||
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
|
||||
])
|
||||
|
||||
def late_buffer_view(x, t, b):
|
||||
if isinstance(t.device, str) and t.device.startswith("DISK"):
|
||||
rngs = b.src[1:]
|
||||
size = prod(shape := [int(r.vmax+1) for r in rngs])
|
||||
if len(shape) == 0:
|
||||
offset = x.src[1].arg
|
||||
else:
|
||||
idxs = x.src[1:]
|
||||
offset = sum(idx.vmin for idx in idxs)
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
|
||||
return b
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.INDEX, name="x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
|
||||
(UPat(Ops.INDEX, name="x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(GroupOp.All).f(Ops.BUFFERIZE, allow_any_len=True, name="b"),
|
||||
late_buffer_view),
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 4. put in buffers for bufferize
|
||||
# TODO: should BUFFERIZE look a lot more like STORE
|
||||
@@ -454,7 +467,7 @@ def bufferize_to_store(x:UOp):
|
||||
# TODO: how is this unified?
|
||||
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
|
||||
Reference in New Issue
Block a user