From 801e199196c19f1f1adee43cff616adce45d0580 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:47:51 -0800 Subject: [PATCH] change buffer to not be pointer [pr] (#8302) --- test/test_schedule.py | 26 ++++++++++----------- test/unit/test_tensor_uop_representation.py | 8 +++---- tinygrad/engine/schedule.py | 2 +- tinygrad/ops.py | 4 ++-- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 3b05fe8407..9325bee2e5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1761,7 +1761,7 @@ class TestSwizzle(unittest.TestCase): def test_permute_rewrite(self): sink = UOp(Ops.STORE, dtypes.void, arg=None, src=( - x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()), + x1:=UOp(Ops.BUFFER, dtypes.float, arg=(1, ('METAL', 16384, dtypes.float)), src=()), x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=( x1, @@ -1773,15 +1773,15 @@ class TestSwizzle(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 16384, dtypes.float)), src=()), x2,)),)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 256, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 16, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=( x11,)),)),)),)),)) @@ -1793,7 +1793,7 @@ class TestSwizzle(unittest.TestCase): # fuse (relu bw, conv2d, conv2d bw, relu) sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 128, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 128, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( @@ -1808,16 +1808,16 @@ class TestSwizzle(unittest.TestCase): UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(9, ('METAL', 96, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(9, ('METAL', 96, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(48, 0, 0, 4, 1, 16, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(16, ('METAL', 432, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(16, ('METAL', 432, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(0, 0, 27, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), x6,)),)),)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(18, ('METAL', 128, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(18, ('METAL', 128, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 3, 2, 3), strides=(64, 4, 2, 0, 1, 0), offset=0, mask=((0, 2), (0, 16), (0, 2), (0, 1), (0, 2), (0, 1)), contiguous=False), View(shape=(1, 2, 1, 16, 3, 2, 3, 2), strides=(0, 576, 0, 36, 12, 6, 2, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)) ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) @@ -1826,7 +1826,7 @@ class TestSwizzle(unittest.TestCase): def test_swizzle_failure_permute(self): sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(20, ('METAL', 65, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( @@ -1834,7 +1834,7 @@ class TestSwizzle(unittest.TestCase): x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()), x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(Ops.WHERE, dtypes.float, arg=None, src=( x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( @@ -1854,12 +1854,12 @@ class TestSwizzle(unittest.TestCase): x15,)), UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()), x10,)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(4, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)) ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) @@ -1952,7 +1952,7 @@ class TestBigGraph(unittest.TestCase): def test_sink_childless_const_alt(self): x = UOp.const(dtypes.int, 0) - y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) + y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int, (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) big_graph = big_graph_rewrite(UOp.sink(x, y), ctx:=ScheduleContext()) self.assertIs(big_graph, UOp(Ops.NOOP)) self.assertEqual(len(ctx.realizes), 0) diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index 1446e331ce..9d71eac66f 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -37,7 +37,7 @@ class TestTensorUopRepresentation(unittest.TestCase): # UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=( # UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=( # UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=( - # UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(-1, 'METAL', 1), src=()), + # UOp(Ops.BUFFER, dtypes.float, arg=(-1, 'METAL', 1), src=()), # UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),)),)),)) # expected: # UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=( @@ -55,14 +55,14 @@ class TestTensorUopRepresentation(unittest.TestCase): # currently, COPY has an extra BUFFER on the output # current: # UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=( - # UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, 'TEST', 3), src=()), + # UOp(Ops.BUFFER, dtypes.float, arg=(2, 'TEST', 3), src=()), # UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=( # UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=( - # UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),)),)),)) + # UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),)),)),)) # expected: # UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=( # UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=( - # UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),)) + # UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),)) @unittest.expectedFailure def test_copyin(self): a = Tensor([1.,2,3]).realize() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a640f66dd0..a142a83152 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -151,7 +151,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]: def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: ctx.bufs.append(x) - return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1) + return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(), (), len(ctx.bufs)-1) append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)]) def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index be97248936..71fce0eb5d 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -438,7 +438,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): from tinygrad.shape.shapetracker import ShapeTracker if op is Ops.CONST: # NOTE: we embed device on CONST with a fake BUFFER uop - fake = UOp(Ops.BUFFER, dtype.ptr(), (UOp(Ops.DEVICE, arg=device),), (-1, 1)) + fake = UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (-1, 1)) # NOTE: BIND stays BIND, UOp.const unbinds here const_uop = arg if isinstance(arg, UOp) else UOp.const(dtype, unwrap(arg)) return UOp(Ops.VIEW, dtype, (fake, const_uop), ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape) @@ -505,7 +505,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): buffer_num = itertools.count(0) @staticmethod def new_buffer(device:str, size:int, dtype:DType) -> UOp: - return UOp(Ops.BUFFER, dtype.ptr(), (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) + return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) @property def device(self) -> str: return unwrap(self._device) @functools.cached_property