diff --git a/test/test_schedule.py b/test/test_schedule.py
index b148f55681..ab8a63b499 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2229,5 +2229,49 @@ class TestTensorUOpSpec(unittest.TestCase):
     t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
     create_schedule_with_vars(list(t.src))
 
+class TestBufferUOp(unittest.TestCase):
+  # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
+  def test_buffer_has_buffer(self):
+    buf = Tensor.empty(10)
+    self.assertIsNotNone(buf.lazydata.buffer)
+    self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,)))
+    # the device Buffer remains unallocated until we run the schedule
+    self.assertFalse(buf.lazydata.buffer.is_allocated())
+    add = buf+1
+    sched = add.schedule()
+    self.assertFalse(buf.lazydata.buffer.is_allocated())
+    run_schedule(sched)
+    self.assertTrue(buf.lazydata.buffer.is_allocated())
+
+  def test_buffer_has_unique_buffer(self):
+    buf = Tensor.empty(10)
+    buf1 = buf.lazydata.buffer
+    buf2 = buf.lazydata.buffer
+    self.assertIs(buf1, buf2)
+
+  # we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
+  def test_buffer_view_allowed(self):
+    add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
+    add.realize()
+    self.assertIsNotNone(add.lazydata.buffer)
+    self.assertEqual(add.lazydata.shape, (1, 1))
+
+  def test_buffer_view_not_allowed(self):
+    permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
+    merged = graph_rewrite(permuted_view.lazydata, remove_movement_ops)
+    with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
+      merged.buffer # cannot access Buffer of a non contiguous VIEW
+
+  def test_buffer_only_after_realize(self):
+    a = Tensor([1])+Tensor([2])
+    # accessing realized will return None
+    self.assertIsNone(a.lazydata.realized)
+    # accessing Buffer will assert
+    with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
+      a.lazydata.buffer # there is no BUFFER on an unrealized ADD
+    # Buffer only exists once we realize it
+    a.realize()
+    self.assertIsNotNone(a.lazydata.buffer)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 97517f562f..84fcbf160f 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -110,7 +110,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
   # shapeless op is passthrough
   # realized is passthrough
   # constants are passthrough
-  if buf.st is None or buf.base.is_realized or buf.base.op is Ops.VIEW or is_constant(buf.base): return buf
+  if buf.st is None or buf.base.is_realized or is_constant(buf.base): return buf
   # view is passthrough
   if buf is not buf.base:
     cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(buf.st)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5ed433f940..310c7a564f 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -516,19 +516,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st))
   @property
   def buffer(self) -> Buffer:
-    if self.base.realized is not None: return self.base.realized
-    if (ret:=buffers.get(self)) is not None: return ret
     if self.op is Ops.VIEW:
       assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
       return self.src[0].buffer
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+    if (cret:=buffers.get(self)) is not None: return cret
     from tinygrad.device import Buffer
     buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
     return ret
   @property
   def realized(self) -> Optional[Buffer]:
     if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized
-    return buffers.get(self) if self.op is Ops.BUFFER else None
+    return self.buffer if self.op is Ops.BUFFER else None
   @property
   def is_realized(self) -> bool: return self.base.realized is not None