always give BUFFER uops Buffers [pr] (#8572)

* always give BUFFER uops Buffers [pr]

* add test_buffer_only_after_realize
This commit is contained in:
qazal
2025-01-11 16:17:09 -05:00
committed by GitHub
parent 98c9e23560
commit 87cbff3ac0
3 changed files with 47 additions and 4 deletions

View File

@@ -2229,5 +2229,49 @@ class TestTensorUOpSpec(unittest.TestCase):
t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
create_schedule_with_vars(list(t.src))
# Tests for the `buffer` / `realized` accessors reached through `Tensor.lazydata`.
class TestBufferUOp(unittest.TestCase):
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer(self):
buf = Tensor.empty(10)
self.assertIsNotNone(buf.lazydata.buffer)
self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,)))
# the device Buffer remains unallocated until we run the schedule
self.assertFalse(buf.lazydata.buffer.is_allocated())
add = buf+1
sched = add.schedule()
# creating the schedule alone must not allocate the input Buffer
self.assertFalse(buf.lazydata.buffer.is_allocated())
run_schedule(sched)
self.assertTrue(buf.lazydata.buffer.is_allocated())
# repeated accesses of .buffer must return the very same Buffer object (identity, not just equality)
def test_buffer_has_unique_buffer(self):
buf = Tensor.empty(10)
buf1 = buf.lazydata.buffer
buf2 = buf.lazydata.buffer
self.assertIs(buf1, buf2)
# we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
def test_buffer_view_allowed(self):
add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
add.realize()
self.assertIsNotNone(add.lazydata.buffer)
self.assertEqual(add.lazydata.shape, (1, 1))
# a permuted view is expected to trip the contiguity assert when .buffer is accessed
def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
merged = graph_rewrite(permuted_view.lazydata, remove_movement_ops)
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
merged.buffer # cannot access Buffer of a non contiguous VIEW
# before realize: .realized is None and .buffer asserts; after realize: .buffer is available
def test_buffer_only_after_realize(self):
a = Tensor([1])+Tensor([2])
# accessing realized will return None
self.assertIsNone(a.lazydata.realized)
# accessing Buffer will assert
with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
a.lazydata.buffer # there is no BUFFER on an unrealized ADD
# Buffer only exists once we realize it
a.realize()
self.assertIsNotNone(a.lazydata.buffer)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -110,7 +110,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
# shapeless op is passthrough
# realized is passthrough
# constants are passthrough
if buf.st is None or buf.base.is_realized or buf.base.op is Ops.VIEW or is_constant(buf.base): return buf
if buf.st is None or buf.base.is_realized or is_constant(buf.base): return buf
# view is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(buf.st)

View File

@@ -516,19 +516,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st))
# Returns the device Buffer for this UOp, creating and caching it in `buffers` when none exists yet.
# NOTE(review): this is a diff hunk rendered without +/- markers, so pre- and post-change lines appear
# together (note the two `buffers.get(self)` lookups below) — confirm against the real file which are live.
@property
def buffer(self) -> Buffer:
if self.base.realized is not None: return self.base.realized
if (ret:=buffers.get(self)) is not None: return ret
# a VIEW forwards to its first source's buffer, but only when its ShapeTracker is contiguous
if self.op is Ops.VIEW:
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
return self.src[0].buffer
# any op other than VIEW or BUFFER fails here
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
# reuse the Buffer already registered for this UOp, if any
if (cret:=buffers.get(self)) is not None: return cret
# function-scope import — presumably avoids a circular import with tinygrad.device; TODO confirm
from tinygrad.device import Buffer
# ImageDType is passed through unchanged; any other dtype is reduced to its .base
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
return ret
# Returns the Buffer behind this UOp, or None when the op is neither a BUFFER nor a single-source VIEW of a BUFFER.
@property
def realized(self) -> Optional[Buffer]:
# a VIEW whose only source is a BUFFER forwards to that source's `realized`
if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized
# NOTE(review): the two returns below look like the before/after versions of one line in the diff;
# as written the second is unreachable — confirm which one the real file keeps.
return buffers.get(self) if self.op is Ops.BUFFER else None
return self.buffer if self.op is Ops.BUFFER else None
# True when the base UOp's `realized` is not None.
@property
def is_realized(self) -> bool: return self.base.realized is not None