consts do not realize

This commit is contained in:
George Hotz
2024-12-17 08:53:53 -08:00
parent 4764a4c172
commit 0794af97db

View File

@@ -18,6 +18,51 @@ class TestTensorUopRepresentation(unittest.TestCase):
print(c.lazydata)
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
def test_consts_do_not_realize(self):
a = Tensor(1)
print(a.lazydata)
pre_realize = a.lazydata
a.realize()
assert a.lazydata is pre_realize
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
pre_realize = a.lazydata
a.realize()
assert a.lazydata is pre_realize
# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:
# UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
# UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(-1, 'METAL', 1), src=()),
# UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),)),)),))
# expected:
# UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
# UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
# UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
# UOp(Ops.DEVICE, dtypes.void, arg="METAL", src=()),)),)),))
@unittest.expectedFailure
def test_consts_dont_have_buffers(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
buffers_in_parents = [x.op for x in a.lazydata.toposort if x.op is Ops.BUFFER]
self.assertEqual(len(buffers_in_parents), 0)
# currently, COPY has an extra BUFFER on the output
# current:
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, 'TEST', 3), src=()),
# UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),)),)),))
# expected:
# UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),))
@unittest.expectedFailure
def test_copyin(self):
a = Tensor([1.,2,3]).realize()