cleanup around disk buffer [pr] (#14428)

style change, prep for refactor
This commit is contained in:
chenyu
2026-01-29 16:18:44 -05:00
committed by GitHub
parent dc977a03b0
commit a979fafae5
2 changed files with 14 additions and 14 deletions

View File

@@ -75,7 +75,7 @@ class BufferCopy(Runner):
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
if disk_supports_fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096:
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
elif isinstance(src.device, str) and src.device.startswith(("DISK", "TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:

View File

@@ -252,22 +252,22 @@ pm_remove_bufferize = PatternMatcher([
])
def late_buffer_view(t:UOp, b:UOp):
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
shape = b.shape
size = prod(shape)
if not (isinstance(b.device, str) and b.device.startswith(("DISK", "TINYFS"))): return b
shape = b.shape
size = prod(shape)
# walk up for the INDEX
x = t
while not any(u.op is Ops.INDEX for u in x.src):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX)
# walk up for the INDEX
x = t
while not any(u.op is Ops.INDEX for u in x.src):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX)
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
return b
to_bufferview = PatternMatcher([
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),