mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
always give BUFFER uops Buffers [pr] (#8572)
* always give BUFFER uops Buffers [pr] * add test_buffer_only_after_realize
This commit is contained in:
@@ -2229,5 +2229,49 @@ class TestTensorUOpSpec(unittest.TestCase):
|
||||
t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
|
||||
create_schedule_with_vars(list(t.src))
|
||||
|
||||
class TestBufferUOp(unittest.TestCase):
|
||||
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
|
||||
def test_buffer_has_buffer(self):
|
||||
buf = Tensor.empty(10)
|
||||
self.assertIsNotNone(buf.lazydata.buffer)
|
||||
self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,)))
|
||||
# the device Buffer remains unallocated until it's we run the schedule
|
||||
self.assertFalse(buf.lazydata.buffer.is_allocated())
|
||||
add = buf+1
|
||||
sched = add.schedule()
|
||||
self.assertFalse(buf.lazydata.buffer.is_allocated())
|
||||
run_schedule(sched)
|
||||
self.assertTrue(buf.lazydata.buffer.is_allocated())
|
||||
|
||||
def test_buffer_has_unique_buffer(self):
|
||||
buf = Tensor.empty(10)
|
||||
buf1 = buf.lazydata.buffer
|
||||
buf2 = buf.lazydata.buffer
|
||||
self.assertIs(buf1, buf2)
|
||||
|
||||
# we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
|
||||
def test_buffer_view_allowed(self):
|
||||
add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
|
||||
add.realize()
|
||||
self.assertIsNotNone(add.lazydata.buffer)
|
||||
self.assertEqual(add.lazydata.shape, (1, 1))
|
||||
|
||||
def test_buffer_view_not_allowed(self):
|
||||
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
|
||||
merged = graph_rewrite(permuted_view.lazydata, remove_movement_ops)
|
||||
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
|
||||
merged.buffer # cannot access Buffer of a non contiguous VIEW
|
||||
|
||||
def test_buffer_only_after_realize(self):
|
||||
a = Tensor([1])+Tensor([2])
|
||||
# accessing realized will return None
|
||||
self.assertIsNone(a.lazydata.realized)
|
||||
# accessing Buffer will assert
|
||||
with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
|
||||
a.lazydata.buffer # there is no BUFFER on an unrealized ADD
|
||||
# Buffer only exists once we realize it
|
||||
a.realize()
|
||||
self.assertIsNotNone(a.lazydata.buffer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -110,7 +110,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
# shapeless op is passthrough
|
||||
# realized is passthrough
|
||||
# constants are passthrough
|
||||
if buf.st is None or buf.base.is_realized or buf.base.op is Ops.VIEW or is_constant(buf.base): return buf
|
||||
if buf.st is None or buf.base.is_realized or is_constant(buf.base): return buf
|
||||
# view is passthrough
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(buf.st)
|
||||
|
||||
@@ -516,19 +516,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st))
|
||||
@property
|
||||
def buffer(self) -> Buffer:
|
||||
if self.base.realized is not None: return self.base.realized
|
||||
if (ret:=buffers.get(self)) is not None: return ret
|
||||
if self.op is Ops.VIEW:
|
||||
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
||||
return self.src[0].buffer
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
from tinygrad.device import Buffer
|
||||
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
|
||||
return ret
|
||||
@property
|
||||
def realized(self) -> Optional[Buffer]:
|
||||
if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized
|
||||
return buffers.get(self) if self.op is Ops.BUFFER else None
|
||||
return self.buffer if self.op is Ops.BUFFER else None
|
||||
@property
|
||||
def is_realized(self) -> bool: return self.base.realized is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user