mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanup around disk buffer [pr] (#14428)
style change, prep for refactor
This commit is contained in:
@@ -75,7 +75,7 @@ class BufferCopy(Runner):
|
||||
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
|
||||
if disk_supports_fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096:
|
||||
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
|
||||
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
elif isinstance(src.device, str) and src.device.startswith(("DISK", "TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
# fast(ish) path, uses readinto in diskbuffers
|
||||
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
|
||||
else:
|
||||
|
||||
@@ -252,22 +252,22 @@ pm_remove_bufferize = PatternMatcher([
|
||||
])
|
||||
|
||||
def late_buffer_view(t:UOp, b:UOp):
|
||||
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
|
||||
shape = b.shape
|
||||
size = prod(shape)
|
||||
if not (isinstance(b.device, str) and b.device.startswith(("DISK", "TINYFS"))): return b
|
||||
shape = b.shape
|
||||
size = prod(shape)
|
||||
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
|
||||
return b
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
|
||||
|
||||
Reference in New Issue
Block a user