From a979fafae539d169889955bc76d1041e181f91ec Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 29 Jan 2026 16:18:44 -0500 Subject: [PATCH] cleanup around disk buffer [pr] (#14428) style change, prep for refactor --- tinygrad/engine/realize.py | 2 +- tinygrad/schedule/rangeify.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index b5cd805604..b03dea7a55 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -75,7 +75,7 @@ class BufferCopy(Runner): getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk if disk_supports_fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096: dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes) - elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'): + elif isinstance(src.device, str) and src.device.startswith(("DISK", "TINYFS")) and hasattr(dest.allocator, '_as_buffer'): # fast(ish) path, uses readinto in diskbuffers src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf) else: diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index e4c38550b3..bc4ee0ab1b 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -252,22 +252,22 @@ pm_remove_bufferize = PatternMatcher([ ]) def late_buffer_view(t:UOp, b:UOp): - if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")): - shape = b.shape - size = prod(shape) + if not (isinstance(b.device, str) and b.device.startswith(("DISK", "TINYFS"))): return b + shape = b.shape + size = prod(shape) - # walk up for the INDEX - x = t - while not any(u.op is Ops.INDEX for u in x.src): - assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise" - x = x.src[0] - x = next(u for u in x.src if u.op is Ops.INDEX) + # walk up for the INDEX + x = t + while not any(u.op is Ops.INDEX for u in x.src): + assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise" + x = x.src[0] + x = next(u for u in x.src if u.op is Ops.INDEX) - if len(shape) == 0: offset = x.src[1].arg - else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0) + if len(shape) == 0: offset = x.src[1].arg + else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0) + + return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:]) - return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:]) - return b to_bufferview = PatternMatcher([ (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view), (UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),